diff --git a/examples/seld_spatialsoundqa/README.md b/examples/seld_spatialsoundqa/README.md
new file mode 100644
index 00000000..5e3e487b
--- /dev/null
+++ b/examples/seld_spatialsoundqa/README.md
@@ -0,0 +1,43 @@
+# SELD_SpatialSoundQA
+
+This example hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].
+
+## Performance and checkpoints
+| Encoder | Projector | PEFT | LLM |
+|---|---|---|---|
+| [Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter | [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) |
+
+## Data preparation
+Prepare the QA data as a JSONL file, one example per line, in the following format:
+```
+{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
+...
+{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"}
+```
+
+## Train a new model
+```bash
+bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_linear_llama_2_7b.sh
+```
+
+## Decoding with checkpoints
+```bash
+bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_linear_llama_2_7b.sh
+```
+
+
+## TODO
+- [x] Decode with checkpoints
+- [ ] Upload SpatialSoundQA dataset
+- [ ] Upload pretrained checkpoints
+- [ ] Update model performance
+
+## Citation
+```
+@article{zheng2024bat,
+  author = {Zheng, Zhisheng and Peng, Puyuan and Ma, Ziyang and Chen, Xie and Choi, Eunsol and Harwath, David},
+  title = {BAT: Learning to Reason about Spatial Sounds with Large Language Models},
+  journal = {arXiv preprint arXiv:2402.01591},
+  year = {2024},
+}
+```
\ No newline at end of file
diff --git a/examples/seld_spatialsoundqa/assets/bat.png b/examples/seld_spatialsoundqa/assets/bat.png
new file mode 100644
index 00000000..945499d0
Binary files /dev/null and b/examples/seld_spatialsoundqa/assets/bat.png differ
diff --git a/examples/seld_spatialsoundqa/conf/ds_config.json b/examples/seld_spatialsoundqa/conf/ds_config.json
new file mode 100644
index 00000000..7ea70e4a
--- /dev/null
+++ b/examples/seld_spatialsoundqa/conf/ds_config.json
@@ -0,0 +1,19 @@
+{
+    "train_micro_batch_size_per_gpu": 4,
+    "gradient_accumulation_steps": 1,
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 1e-4
+        }
+    },
+    "fp16": {
+        "enabled": true
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu"
+        }
+    }
+}
\ No newline at end of file
diff --git a/examples/seld_spatialsoundqa/finetune_seld.py b/examples/seld_spatialsoundqa/finetune_seld.py
new file mode 100644
index 00000000..f719c276
--- /dev/null
+++ b/examples/seld_spatialsoundqa/finetune_seld.py
@@ -0,0 +1,45 @@
+import hydra
+import logging
+from dataclasses import dataclass, field
+from omegaconf import DictConfig, ListConfig, OmegaConf
+
+from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig
+from slam_llm.pipeline.finetune import main as train
+
+@dataclass
+class RunConfig:
+    dataset_config: DataConfig = 
field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + peft_config: PeftConfig = field(default_factory=PeftConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + train(kwargs) + + +if __name__ == "__main__": + main_hydra() \ No newline at end of file diff --git a/examples/seld_spatialsoundqa/inference_seld_batch.py b/examples/seld_spatialsoundqa/inference_seld_batch.py new file mode 100644 index 00000000..4834d182 --- /dev/null +++ b/examples/seld_spatialsoundqa/inference_seld_batch.py @@ -0,0 +1,53 @@ +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from typing import Optional + +from slam_llm.pipeline.inference_batch import main as inference +from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + peft_config: PeftConfig = field(default_factory=PeftConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + decode_log: str = field( + default="output/decode_log", + metadata={"help": "The prefix for the decode output"}, + ) + ckpt_path: str = field( + default="output/model.pt", metadata={"help": "The path to projector checkpoint"} + ) + peft_ckpt: Optional[str] = field( + default=None, + metadata={ + "help": "The path to peft checkpoint, should be a directory including adapter_config.json" + }, + ) + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + # kwargs = to_plain_list(cfg) + log_level = getattr(logging, cfg.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if cfg.get("debug", False): + import pdb + + pdb.set_trace() + + inference(cfg) + + +if __name__ == "__main__": + main_hydra() diff --git a/examples/seld_spatialsoundqa/model/slam_model_seld.py b/examples/seld_spatialsoundqa/model/slam_model_seld.py new file mode 100644 index 00000000..935b5a3b --- /dev/null +++ b/examples/seld_spatialsoundqa/model/slam_model_seld.py @@ -0,0 +1,154 @@ +import torch +import os +import logging +from slam_llm.models.slam_model 
import ( + slam_model, + setup_tokenizer, + setup_encoder, + setup_encoder_projector, + setup_llm, +) +from slam_llm.utils.train_utils import print_model_size + +logger = logging.getLogger(__name__) + +def model_factory(train_config, model_config, **kwargs): + # return necessary components for training + tokenizer = setup_tokenizer(train_config, model_config, **kwargs) + + encoder = setup_encoder(train_config, model_config, **kwargs) + + # llm + llm = setup_llm(train_config, model_config, **kwargs) + + # projector + encoder_projector = setup_encoder_projector( + train_config, model_config, **kwargs + ) + model = slam_model_seld( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + ckpt_path = kwargs.get( + "ckpt_path", None + ) # FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) + if ckpt_path is not None: + logger.info("loading other parts from: {}".format(ckpt_path)) + ckpt_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt_dict, strict=False) + + print_model_size( + model, + train_config, + ( + int(os.environ["RANK"]) + if train_config.enable_fsdp or train_config.enable_ddp + else 0 + ), + ) + return model, tokenizer + + +class slam_model_seld(slam_model): + def __init__( + self, + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ): + super().__init__( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + @torch.no_grad() + def inference( + self, + wav_path=None, + prompt=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=None, + assistant_model=None, + streamer=None, + negative_prompt_ids=None, + negative_prompt_attention_mask=None, + **kwargs, + ): + # inference for asr model + + device = kwargs.get("device", "cuda") + if os.path.exists(wav_path): # Audio-Text QA + import whisper + + audio_raw = whisper.load_audio(wav_path) + audio_raw = whisper.pad_or_trim(audio_raw) + + mel_size = getattr( + self.dataset_config, "mel_size", 80 + ) # 80 for large v1 and v2, 128 for large v3 + audio_mel = ( + whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size) + .permute(1, 0)[None, :, :] + .to(device) + ) + + encoder_outs = self.encoder.extract_variable_length_features( + audio_mel.permute(0, 2, 1) + ) + + if self.model_config.encoder_projector == "q-former": + audio_mel_post_mask = torch.ones( + encoder_outs.size()[:-1], dtype=torch.long + ).to(encoder_outs.device) + encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + else: # Text QA + encoder_outs = torch.empty( + 1, 0, self.llm.model.embed_tokens.embedding_dim + ).to(device) + + prompt = "USER: {}\n ASSISTANT:".format(prompt) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device) + + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(prompt_ids) + elif hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) + + inputs_embeds = torch.cat( + (encoder_outs, inputs_embeds[None, :, :]), dim=1 + ) # [audio,prompt] + + attention_mask = 
torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to( + inputs_embeds.device + ) + + # generate + model_outputs = self.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs + ) + + return model_outputs diff --git a/examples/seld_spatialsoundqa/scripts/decode_spatial-ast_linear_llama_2_7b.sh b/examples/seld_spatialsoundqa/scripts/decode_spatial-ast_linear_llama_2_7b.sh new file mode 100755 index 00000000..3cd7d119 --- /dev/null +++ b/examples/seld_spatialsoundqa/scripts/decode_spatial-ast_linear_llama_2_7b.sh @@ -0,0 +1,59 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 + +SLAM_DIR=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/SLAM-LLM +cd $SLAM_DIR +code_dir=examples/seld_spatialsoundqa + +stage=classification +qa_data_root=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/data/SpatialAudio/closed-end +reverb_data_root=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/data/SpatialAudio/reverb/mp3d +anechoic_data_root=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/data/AudioSet + +audio_encoder_path=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/models/SpatialAST/SpatialAST.pth +llm_path=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/models/llama-2-hf + +split=eval +# output_dir=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/SLAM-LLM/outputs/bat-vicuna-7b-v1.5-spatialAST-qformer-steplrwarmupkeep1e-4-${stage}-$(date +"%Y%m%d") +output_dir=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/SLAM-LLM/outputs/bat-llama-2-spatialAST-qformer-steplrwarmupkeep1e-4-classification-20240507 +ckpt_path=$output_dir/bat_epoch_2_step_2576 +decode_log=$ckpt_path/decode_${split}_beam4 + +# -m debugpy --listen 5678 --wait-for-client +python -u $code_dir/inference_seld_batch.py \ + --config-path "conf" \ + hydra.run.dir=$ckpt_path \ + ++model_config.llm_name=llama-2-7b \ + ++model_config.llm_path=$llm_path \ + ++model_config.llm_dim=4096 \ + ++model_config.encoder_name=SpatialAST \ + ++model_config.encoder_projector=q-former \ + ++model_config.encoder_ckpt=$audio_encoder_path \ + ++dataset_config.stage=$stage \ + ++dataset_config.qa_data_root=$qa_data_root \ + ++dataset_config.anechoic_data_root=$anechoic_data_root \ + ++dataset_config.reverb_data_root=$reverb_data_root \ + ++dataset_config.fix_length_audio=64 \ + ++dataset_config.inference_mode=true \ + ++train_config.model_name=bat \ + ++train_config.freeze_encoder=true \ + ++train_config.freeze_llm=true \ + ++train_config.batching_strategy=custom \ + ++train_config.num_epochs=1 \ + ++train_config.val_batch_size=8 \ + ++train_config.num_workers_dataloader=2 \ + ++train_config.output_dir=$output_dir \ + ++train_config.use_peft=true \ + ++peft_config.peft_method=llama_adapter \ + ++log_config.log_file=$output_dir/test.log \ + ++decode_log=$decode_log \ + ++ckpt_path=$ckpt_path/model.pt \ + # ++peft_ckpt=$ckpt_path \ + # ++train_config.use_peft=true \ + # ++train_config.peft_config.r=32 \ + # ++dataset_config.normalize=true \ + # ++model_config.encoder_projector=q-former \ + # ++dataset_config.fix_length_audio=64 \ diff --git a/examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_linear_llama_2_7b.sh b/examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_linear_llama_2_7b.sh new file mode 100755 index 00000000..1d9fdf9b --- /dev/null +++ b/examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_linear_llama_2_7b.sh @@ -0,0 +1,75 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 #,1,2,3 +export 
TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +SLAM_DIR=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/downloads/SLAM-LLM +cd $SLAM_DIR +code_dir=examples/seld_spatialsoundqa + +audio_encoder_path=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/models/SpatialAST/SpatialAST.pth +llm_path=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/models/llama-2-hf + +stage=stage1-clsdoa +qa_data_root=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/data/SpatialAudio/closed-end +reverb_data_root=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/data/SpatialAudio/reverb/mp3d +anechoic_data_root=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/data/AudioSet + +output_dir=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/SLAM-LLM/outputs/bat-llama-2-spatialAST-qformer-steplrwarmupkeep1e-4-${stage}-$(date +"%Y%m%d") + +hydra_args=" +hydra.run.dir=$output_dir \ +++model_config.llm_name=llama-2-7b \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=SpatialAST \ +++model_config.encoder_projector=q-former \ +++model_config.encoder_ckpt=$audio_encoder_path \ +++dataset_config.stage=$stage \ +++dataset_config.qa_data_root=$qa_data_root \ +++dataset_config.anechoic_data_root=$anechoic_data_root \ +++dataset_config.reverb_data_root=$reverb_data_root \ +++dataset_config.fix_length_audio=64 \ +++train_config.model_name=bat \ +++train_config.num_epochs=2 \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.warmup_steps=1000 \ +++train_config.total_steps=100000 \ +++train_config.lr=1e-4 \ +++train_config.validation_interval=10000 \ +++train_config.batch_size_training=8 \ +++train_config.val_batch_size=4 \ +++train_config.num_workers_dataloader=2 \ +++train_config.output_dir=$output_dir \ +++train_config.use_peft=true \ +++peft_config.peft_method=llama_adapter \ +++metric=acc \ +++log_config.log_file=$output_dir/log.txt \ +" + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then + python -u $code_dir/finetune_seld.py \ + --config-path "conf" \ + $hydra_args +else + torchrun \ + --nnodes 1 \ + --nproc_per_node 4 \ + --master_port=29503 \ + $code_dir/finetune_seld.py \ + --config-path "conf" \ + ++train_config.enable_fsdp=false \ + ++train_config.enable_ddp=true \ + ++train_config.use_fp16=true \ + $hydra_args +fi diff --git a/examples/seld_spatialsoundqa/seld_config.py b/examples/seld_spatialsoundqa/seld_config.py new file mode 100644 index 00000000..00f4c7bd --- /dev/null +++ b/examples/seld_spatialsoundqa/seld_config.py @@ -0,0 +1,114 @@ +from dataclasses import dataclass, field +from typing import Optional, List + +@dataclass +class ModelConfig: + file: str = "examples/seld_spatialsoundqa/model/slam_model_seld.py:model_factory" + llm_name: str = "vicuna-13b-v1.5" + llm_path: str = "PATH/to/LLAMA/7B" + llm_type: str = "decoder_only" + llm_dim: int = 4096 + + encoder_name: Optional[str] = None + encoder_ckpt: Optional[str] = None + encoder_projector: str = "q-former" + encoder_dim: int = 768 + +@dataclass +class PeftConfig: + peft_method: str = "lora" # None , llama_adapter, prefix + r: int = 8 + lora_alpha: int = 32 + target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ]) + bias: str = "none" + task_type: str = "CAUSAL_LM" + lora_dropout: float = 0.05 + inference_mode: bool = False + +@dataclass +class 
TrainConfig: + model_name: str = "vicuna-7b-v1.5" + enable_ddp:bool = False + enable_deepspeed:bool = False + enable_fsdp:bool = False + low_cpu_fsdp:bool = False + run_validation:bool = True + batch_size_training:int = 4 + batching_strategy:str = field(default="packing", metadata={ + "help":"alternative: padding" + }) # + context_length:int = 4096 + gradient_accumulation_steps:int = 1 + num_epochs:int = 3 + num_workers_dataloader:int = 1 + warmup_steps:int = 1000 + total_steps:int = 100000 + validation_interval:int = 1000 + lr:float = 1e-4 + weight_decay:float = 0.0 + gamma:float = 0.85 + seed:int = 42 + use_fp16:bool = False + mixed_precision:bool = True + val_batch_size:int = 1 + + use_peft:bool = False + peft_config:PeftConfig = field(default_factory=PeftConfig) + output_dir:str = "PATH/to/save/PEFT/model" + freeze_layers:bool = False + num_freeze_layers:int = 1 + quantization:bool = False + one_gpu:bool = False + save_model:bool = True + dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP + dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP + save_optimizer:bool = False # will be used if using FSDP + use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels + run_test_during_validation:bool = False + run_test_during_validation_file:str = "test.wav" + run_test_during_validation_prompt:str = "" + freeze_llm:bool = field(default=False, metadata={ + "help": "whether to freeze llm when finetuning, should be true when use peft finetuning" + }) + freeze_encoder:bool = False + +@dataclass +class DataConfig: + dataset: str = "spatial_audio_dataset" + file: str = "src/slam_llm/datasets/spatial_audio_dataset.py:get_spatial_audio_dataset" + ext_audio: str = ".wav" + train_split: str = "train" + test_split: str = "eval" + + stage: Optional[str] = None + + qa_data_root: Optional[str] = None + anechoic_data_root: Optional[str] = None + reverb_data_root: Optional[str] = None + channel_type: str = "binaural" + normalize: bool = True + max_words: Optional[int] = None + fix_length_audio: Optional[int] = None + inference_mode: bool = False + +@dataclass +class FSDPConfig: + mixed_precision: bool = True + use_fp16: bool = False + # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD + sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP + checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. 
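+    # NOTE: the example finetune script launches multi-GPU runs with ++train_config.enable_ddp=true
+    # and ++train_config.enable_fsdp=false, which is why NO_SHARD is the default above; switch back
+    # to FULL_SHARD if you actually enable FSDP.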
+ fsdp_activation_checkpointing: bool = True + fsdp_cpu_offload: bool = False + pure_bf16: bool = False + optimizer: str = "AdamW" + +@dataclass +class LogConfig: + use_wandb: bool = False + wandb_dir: str = "/root/test_wandb" + wandb_entity_name: str = "project_name" + wandb_project_name: str = "project_name" + wandb_exp_name: str = "exp_name" + log_file: str = "/root/test.log" + log_interval: int = 5 diff --git a/src/slam_llm/datasets/spatial_audio_dataset.py b/src/slam_llm/datasets/spatial_audio_dataset.py new file mode 100644 index 00000000..97775243 --- /dev/null +++ b/src/slam_llm/datasets/spatial_audio_dataset.py @@ -0,0 +1,244 @@ +import os +import random +import json, yaml +import copy +import h5py + +import numpy as np +import soundfile as sf +from scipy import signal + +import torch +from torch.utils.data import Dataset + +def format_prompt(instruction, input=None): + PROMPT_DICT = { + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" + ), + "prompt_no_input": ( + "Based on the audio you've heard, refer to the instruction and provide a response.\n\n" + "### Instruction:\n{instruction}\n\n### Response:" + ), + } + if input is None: + return PROMPT_DICT['prompt_no_input'].format_map({'instruction': instruction}) + else: + return PROMPT_DICT["prompt_input"].format_map({'instruction': instruction, 'input': input}) + + +class SpatialAudioDatasetJsonl(Dataset): + def __init__( + self, + dataset_config, + tokenizer, + split, + ): + super().__init__() + dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.jsonl') + with open(dataset_path) as f: + self.data = [json.loads(line) for line in f.readlines()] + + self.anechoic_data_root = dataset_config['anechoic_data_root'] # which is AudioSet in this case + self.reverb_data_root = dataset_config['reverb_data_root'] + self.channel_type = dataset_config['channel_type'] + + self.ext_audio = dataset_config['ext_audio'] + self.max_words = dataset_config['max_words'] + self.fix_length_audio = dataset_config.get("fix_length_audio", -1) + + self.tokenizer = tokenizer + + self.normalize = dataset_config['normalize'] + self.inference_mode = dataset_config['inference_mode'] + self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + + audio_path = os.path.join(self.anechoic_data_root, sample['audio_id'] + self.ext_audio) + reverb_path = os.path.join(self.reverb_data_root, self.channel_type, sample['reverb_id']) + + if sample['audio_id2'] is not None and sample['reverb_id2'] is not None: + audio_path2 = os.path.join(self.anechoic_data_root, sample['audio_id2'] + self.ext_audio) + reverb_path2 = os.path.join(self.reverb_data_root, self.channel_type, sample['reverb_id2']) + else: + audio_path2 = None + reverb_path2 = None + + waveforms = self.load_waveform(audio_path, reverb_path, audio_path2, reverb_path2) + + if self.fix_length_audio > 0: + audio_length = self.fix_length_audio + audio_pseudo = torch.full((audio_length,), -1) # placeholder + + prompt = sample['question'] + prompt = format_prompt(prompt, None) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + + if self.inference_mode: + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + 
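+            # audio_pseudo is a block of fix_length_audio placeholder ids (-1) prepended to the prompt;
+            # the collator's modality_mask marks these positions so the SLAM-LLM forward pass can
+            # replace them with the projected Spatial-AST features before running the LLM.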
example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] + example_mask = example_ids.ge(-1) # [True,True] + + return { + "input_ids": example_ids, + "attention_mask": example_mask, + "audio": waveforms, + "audio_length": audio_length, + "key": f"{sample['question_type']}-{sample['question_id']}", + "target": sample['answer'], + } + + answer = sample['answer'] + example = prompt + answer + + example_ids = self.tokenizer.encode(example) # [prompt,answer] + example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] + example_ids = torch.tensor(example_ids, dtype=torch.int64) + example_ids = torch.cat((audio_pseudo, example_ids)) # [audio,prompt,answer,eos] + + labels_ids = copy.deepcopy(example_ids) # [audio,prompt,answer,eos] + labels_ids[:audio_length + prompt_length] = -1 # [-1,-1,answer,eos]; + example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] + + label_mask = labels_ids.ge(0) # [False,False,True,True] + example_ids[~example_mask] = 0 # [audio,prompt,answer,eos] + labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,answer,eos] + + return { + "input_ids": example_ids, + "labels": labels_ids, + "attention_mask": example_mask, + "audio": waveforms, + "audio_length": audio_length + } + + @classmethod + def normalize_audio(cls, audio_data, target_dBFS=-14.0): + rms = np.sqrt(np.mean(audio_data**2)) # Calculate the RMS of the audio + + if rms == 0: # Avoid division by zero in case of a completely silent audio + return audio_data + + current_dBFS = 20 * np.log10(rms) # Convert RMS to dBFS + gain_dB = target_dBFS - current_dBFS # Calculate the required gain in dB + gain_linear = 10 ** (gain_dB / 20) # Convert gain from dB to linear scale + normalized_audio = audio_data * gain_linear # Apply the gain to the audio data + return normalized_audio + + @classmethod + def load_waveform(cls, audio_path, reverb_path=None, audio_path2=None, reverb_path2=None, normalize=True): + waveform, sr = sf.read(audio_path) + + if len(waveform.shape) > 1: + waveform = waveform[:, 0] + if sr != 32000: + waveform = signal.resample_poly(waveform, 32000, sr) + sr = 32000 + if normalize: + waveform = cls.normalize_audio(waveform, -14.0) + + waveform = waveform.reshape(1, -1) + if reverb_path is not None: + reverb = np.load(reverb_path) + waveform = signal.fftconvolve(waveform, reverb, mode='full') + + waveform = torch.from_numpy(waveform).float() + waveform = cls.padding(waveform, max_length=10*sr) + + if audio_path2 is not None and reverb_path2 is not None: + waveform2, sr2 = sf.read(audio_path2) + + if len(waveform2.shape) > 1: + waveform2 = waveform2[:, 0] + if sr2 != 32000: + waveform2 = signal.resample_poly(waveform2, 32000, sr2) + sr2 = 32000 + if normalize: + waveform2 = cls.normalize_audio(waveform2, -14.0) + + waveform2 = waveform2.reshape(1, -1) + reverb2 = np.load(reverb_path2) + waveform2 = signal.fftconvolve(waveform2, reverb2, mode='full') + waveform2 = torch.from_numpy(waveform2).float() + waveform2 = cls.padding(waveform2, max_length=10*sr) + + waveform = (waveform + waveform2) / 2 + return waveform + + @classmethod + def padding(cls, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if sequence.ndimension() == 2: + if sequence.shape[1] < max_length: + sequence = torch.nn.functional.pad(sequence, (0, max_length - sequence.shape[1])) + else: 
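+                    # 2-D inputs (channel-first waveforms) longer than max_length are truncated along the last axis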
+ sequence = sequence[:, :max_length] + else: + if len(sequence) < max_length: + sequence = torch.cat( + (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, np.ndarray): + if len(sequence) < max_length: + sequence = np.concatenate( + (sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + + def collator(self, samples): + assert samples is not None + input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) + input_ids = torch.stack([ + self.padding(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) for s in samples]) + attention_mask = torch.stack([ + self.padding(s['attention_mask'], input_ids_max_length, False) for s in samples]) + + audio = torch.stack([s['audio'] for s in samples]) + + modality_mask = torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + modality_mask[line, :sample['audio_length']] = 1 + + if self.inference_mode: + keys = [s['key'] for s in samples] + targets = [s['target'] for s in samples] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "audio": audio, + "modality_mask": modality_mask, + "keys": keys, + "targets": targets + } + + labels = torch.stack([self.padding(s['labels'], input_ids_max_length, self.IGNORE_INDEX) + for s in samples]) + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "audio": audio, + "modality_mask": modality_mask + } + +def get_spatial_audio_dataset(dataset_config, tokenizer, split): + dataset = SpatialAudioDatasetJsonl(dataset_config, tokenizer, split) + return dataset \ No newline at end of file diff --git a/src/slam_llm/models/SpatialAST/SpatialAST.py b/src/slam_llm/models/SpatialAST/SpatialAST.py new file mode 100644 index 00000000..ae750a39 --- /dev/null +++ b/src/slam_llm/models/SpatialAST/SpatialAST.py @@ -0,0 +1,122 @@ +import torch +import torch.nn as nn + +from torchlibrosa.stft import STFT, LogmelFilterBank +from timm.models.layers import to_2tuple + +from .vision_transformer import VisionTransformer as _VisionTransformer + +def conv3x3(in_channels, out_channels, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) + +class PatchEmbed_new(nn.Module): + """ Flexible Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + + self.img_size = img_size + self.patch_size = patch_size + self.in_chans = in_chans + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches + _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w + self.patch_hw = (h, w) + self.num_patches = h*w + + def get_output_shape(self, img_size): + return self.proj(torch.randn(1, self.in_chans, img_size[0], img_size[1])).shape + + def forward(self, x): + B, C, H, W = x.shape + + x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12 + x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212 + x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768 + return x + +class BinauralEncoder(_VisionTransformer): + """ Spatial Audio Spectrogram Transformer designed for 
Sound Event Localization and Detection + -------------------------------------------------------- + References: + Spatial-AST from BAT: https://github.com/zszheng147/Spatial-AST and https://arxiv.org/abs/2402.01591 + -------------------------------------------------------- + """ + def __init__(self, num_cls_tokens=3, **kwargs): + super(BinauralEncoder, self).__init__(**kwargs) + img_size = (1024, 128) # 1024, 128 + in_chans = 1 + emb_dim = 768 + + del self.cls_token + self.num_cls_tokens = num_cls_tokens + self.cls_tokens = nn.Parameter(torch.zeros(1, num_cls_tokens, emb_dim)) + + self.patch_embed = PatchEmbed_new( + img_size=img_size, patch_size=(16, 16), + in_chans=in_chans, embed_dim=emb_dim, stride=16 + ) # no overlap. stride=img_size=16 + + num_patches = self.patch_embed.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False) # fixed sin-cos embedding + + self.spectrogram_extractor = STFT( + n_fft=1024, hop_length=320, win_length=1024, window='hann', + center=True, pad_mode='reflect', freeze_parameters=True + ) + self.logmel_extractor = LogmelFilterBank( + sr=32000, n_fft=1024, n_mels=128, fmin=50, + fmax=14000, ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True + ) + + self.conv_downsample = nn.Sequential( + conv3x3(4, 1), + nn.BatchNorm2d(1), + nn.GELU(), + ) + + self.bn = nn.BatchNorm2d(2, affine=False) + del self.norm # remove the original norm + + self.target_frame = 1024 + + def forward_features_mask(self, x): + B = x.shape[0] #bsz, 512, 768 (unmasked) + + x = x + self.pos_embed[:, 1:, :] + + cls_tokens = self.cls_tokens + cls_tokens = cls_tokens.expand(B, -1, -1) + x = torch.cat([cls_tokens, x], dim=1) # bsz, 512 + 2 + 10, 768 + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + return x + + @torch.no_grad() + def forward(self, waveforms): + B, C, T = waveforms.shape + + waveforms = waveforms.reshape(B * C, T) + real, imag = self.spectrogram_extractor(waveforms) + + log_mel = self.logmel_extractor(torch.sqrt(real**2 + imag**2)).reshape(B, C, -1, 128) + log_mel = self.bn(log_mel) + + IPD = torch.atan2(imag[1::2], real[1::2]) - torch.atan2(imag[::2], real[::2]) + x = torch.cat([log_mel, torch.matmul(torch.cat([torch.cos(IPD), torch.sin(IPD)], dim=1), self.logmel_extractor.melW)], dim=1) + + if x.shape[2] < self.target_frame: + x = nn.functional.interpolate(x, (self.target_frame, x.shape[3]), mode="bicubic", align_corners=True) + + x = self.conv_downsample(x) + x = self.patch_embed(x) + x = self.forward_features_mask(x) + + return x \ No newline at end of file diff --git a/src/slam_llm/models/SpatialAST/vision_transformer.py b/src/slam_llm/models/SpatialAST/vision_transformer.py new file mode 100644 index 00000000..9c1fc115 --- /dev/null +++ b/src/slam_llm/models/SpatialAST/vision_transformer.py @@ -0,0 +1,239 @@ +import torch +from torch import nn + +from timm.models.layers import to_2tuple, DropPath, trunc_normal_ + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. 
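+        Kept from the reference ViT implementation; the Spatial-AST encoder in this example does not
+        pass a hybrid CNN backbone, so this module is unused here.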
+ """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed_new(nn.Module): + """ Flexible Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + + self.img_size = img_size + self.patch_size = patch_size + self.in_chans = in_chans + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches + _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w + self.patch_hw = (h, w) + self.num_patches = h*w + + def get_output_shape(self, img_size): + return self.proj(torch.randn(1, self.in_chans, img_size[0], img_size[1])).shape + + def forward(self, x): + B, C, H, W = x.shape + + x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12 + x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212 + x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768 + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + # NOTE as per official impl, 
we could have a pre-logits representation dense layer + tanh here + #self.repr = nn.Linear(embed_dim, representation_size) + #self.repr_act = nn.Tanh() + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x \ No newline at end of file diff --git a/src/slam_llm/models/encoder.py b/src/slam_llm/models/encoder.py index 0b2bbff3..bf34db6f 100644 --- a/src/slam_llm/models/encoder.py +++ b/src/slam_llm/models/encoder.py @@ -66,6 +66,20 @@ def load(cls, model_config): def extract_features(self, source, padding_mask): return self.model.extract_features(source, padding_mask = padding_mask, mask=False, remove_extra_tokens = False)['x'] +class SpatialASTEncoder: + @classmethod + def load(cls, model_config): + from functools import partial + from .SpatialAST import SpatialAST + binaural_encoder = SpatialAST.BinauralEncoder( + num_classes=355, drop_path_rate=0.1, num_cls_tokens=3, + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6) + ) + + checkpoint = torch.load(model_config.encoder_ckpt, map_location='cpu') + binaural_encoder.load_state_dict(checkpoint['model'], strict=False) + return binaural_encoder class WavLMEncoder(nn.Module): def __init__(self, config, model): diff --git a/src/slam_llm/models/slam_model.py b/src/slam_llm/models/slam_model.py index 38d4b471..a3cbfe80 100644 --- a/src/slam_llm/models/slam_model.py +++ b/src/slam_llm/models/slam_model.py @@ -308,6 +308,8 @@ def forward(self, encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask) # bs*seq*dim if self.model_config.encoder_name == "eat": encoder_outs = self.encoder.model.extract_features(audio_mel.unsqueeze(dim=1), padding_mask = None, mask=False, remove_extra_tokens = False)['x'] + if self.model_config.encoder_name == "SpatialAST": + encoder_outs = self.encoder(audio) # output: [bs, seq_len=3+512, dim=768] if self.model_config.encoder_name == "wavlm": encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask if self.model_config.encoder_name == "hubert":
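For readers who want to sanity-check the new data pipeline without launching the full Hydra/DeepSpeed training run, the sketch below instantiates `SpatialAudioDatasetJsonl` and its collator directly. It is illustrative only and not part of this PR: all paths are placeholders, and the pad-token line reflects the usual LLaMA-2 convention rather than anything defined in this diff.

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from slam_llm.datasets.spatial_audio_dataset import SpatialAudioDatasetJsonl

# All paths are placeholders; point them at your local SpatialSoundQA manifests,
# AudioSet audio, mp3d binaural reverbs, and LLaMA-2 checkpoint.
dataset_config = {
    "qa_data_root": "/path/to/SpatialAudio/closed-end",
    "stage": "classification",
    "anechoic_data_root": "/path/to/AudioSet",
    "reverb_data_root": "/path/to/SpatialAudio/reverb/mp3d",
    "channel_type": "binaural",
    "ext_audio": ".wav",
    "max_words": None,
    "fix_length_audio": 64,   # reserve 64 placeholder positions for the projected audio features
    "normalize": True,
    "inference_mode": True,   # return question keys and reference answers instead of labels
}

tokenizer = AutoTokenizer.from_pretrained("/path/to/llama-2-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id  # LLaMA-2 has no pad token by default

dataset = SpatialAudioDatasetJsonl(dataset_config, tokenizer, split="eval")
loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collator)

batch = next(iter(loader))
print(batch["audio"].shape)      # [2, 2, 320000]: two binaural waveforms, 10 s at 32 kHz
print(batch["input_ids"].shape)  # [2, 64 + padded prompt length]
print(batch["targets"])          # reference answers, e.g. "accelerating, revving, vroom; car; vehicle"
```

Each `audio` tensor is the anechoic AudioSet clip convolved with a binaural mp3d room impulse response (and, for MIXUP questions, averaged with a second spatialized source), normalized to -14 dBFS and padded to 10 s at 32 kHz, which matches the (B, C, T) input expected by `BinauralEncoder.forward`.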