add BAT branch #69
@@ -0,0 +1,43 @@
# <img src="assets/bat.png" alt="SELD_SpatialSoundQA" width="25" height="25"> SELD_SpatialSoundQA

This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].

## Performance and checkpoints
Encoder | Projector | PEFT | LLM
|---|---|---|---|
[Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter | [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b)
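
As a usage sketch (not part of this PR), the Spatial-AST encoder weights linked in the table can be fetched programmatically; the repo id and filename below are read off the Hugging Face URL above, and `huggingface_hub` is assumed to be installed. Note that the meta-llama/Llama-2-7b weights are gated and require accepting Meta's license on Hugging Face before download.

```python
# Hypothetical download helper; assumes `pip install huggingface_hub` and that the
# repo id / filename match the checkpoint URL in the table above.
from huggingface_hub import hf_hub_download

encoder_ckpt = hf_hub_download(repo_id="zhisheng01/Bat", filename="spatial-ast.pth")
print(encoder_ckpt)  # local cache path of the downloaded Spatial-AST checkpoint
```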

## Data preparation

> Review comment: show an example here; for more, link to the demo page.

You need to prepare the data JSONL in this format.

> Review comment: link to the dataset page.

```
{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
...
{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"}
```
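
As an illustration (not part of this PR), a JSONL file in the format above can be inspected line by line; the file name below is a placeholder.

```python
# Minimal sketch for inspecting the QA JSONL; "eval.jsonl" is a placeholder path.
import json

with open("eval.jsonl") as f:
    for line in f:
        item = json.loads(line)
        # Each entry pairs an audio clip id with a reverb id and one QA pair.
        print(item["question_type"], item["question"], "->", item["answer"])
```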

## Train a new model
```bash
bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_linear_llama_2_7b.sh
```

## Decoding with checkpoints
```bash
bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_linear_llama_2_7b.sh
```

## TODO
- [x] Decode with checkpoints
- [ ] Upload SpatialSoundQA dataset
- [ ] Upload pretrained checkpoints
- [ ] Update model performance

## Citation
```
@article{zheng2024bat,
  author  = {Zheng, Zhisheng and Peng, Puyuan and Ma, Ziyang and Chen, Xie and Choi, Eunsol and Harwath, David},
  title   = {BAT: Learning to Reason about Spatial Sounds with Large Language Models},
  journal = {arXiv preprint arXiv:2402.01591},
  year    = {2024},
}
```
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
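
As a hedged sketch (not part of this PR): a ZeRO-3 config like the one above is typically handed to `deepspeed.initialize`, which then owns the optimizer, fp16 handling, and CPU offload. The toy model and config path below are placeholders, and actually running this requires a CUDA machine with the script launched via the `deepspeed` CLI.

```python
# Illustrative only; "ds_config.json" stands in for the config file above, and
# `deepspeed` is assumed to be installed and to set up distributed state at launch.
import torch
import deepspeed

model = torch.nn.Linear(512, 512)  # toy stand-in for the real SLAM-LLM model
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",
)
```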
@@ -0,0 +1,45 @@
import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf

from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig
from slam_llm.pipeline.finetune import main as train


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()
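
To make the configuration flow concrete: this entrypoint merges a structured `RunConfig` with whatever Hydra collects from the command line (the `++key=value` overrides used in the decode script later in this PR). The snippet below is an illustration, not part of the PR, and uses toy dataclasses so it is self-contained.

```python
# Illustration of how dotlist-style overrides merge into a structured config.
# ToyTrainConfig/ToyRunConfig are stand-ins, not the real seld_config classes.
from dataclasses import dataclass, field
from omegaconf import OmegaConf


@dataclass
class ToyTrainConfig:
    model_name: str = "bat"
    num_epochs: int = 1


@dataclass
class ToyRunConfig:
    train_config: ToyTrainConfig = field(default_factory=ToyTrainConfig)
    metric: str = "acc"


base = OmegaConf.structured(ToyRunConfig())
# Roughly what "++train_config.num_epochs=3 ++metric=f1" contributes on the CLI.
overrides = OmegaConf.from_dotlist(["train_config.num_epochs=3", "metric=f1"])
cfg = OmegaConf.merge(base, overrides)
print(cfg.train_config.num_epochs, cfg.metric)  # 3 f1
```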
@@ -0,0 +1,53 @@
import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Optional

from slam_llm.pipeline.inference_batch import main as inference
from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    decode_log: str = field(
        default="output/decode_log",
        metadata={"help": "The prefix for the decode output"},
    )
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to peft checkpoint, should be a directory including adapter_config.json"
        },
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    # kwargs = to_plain_list(cfg)
    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if cfg.get("debug", False):
        import pdb

        pdb.set_trace()

    inference(cfg)


if __name__ == "__main__":
    main_hydra()
@@ -0,0 +1,154 @@
import torch
import os
import logging
from slam_llm.models.slam_model import (
    slam_model,
    setup_tokenizer,
    setup_encoder,
    setup_encoder_projector,
    setup_llm,
)
from slam_llm.utils.train_utils import print_model_size

logger = logging.getLogger(__name__)


def model_factory(train_config, model_config, **kwargs):
    # return necessary components for training
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    encoder = setup_encoder(train_config, model_config, **kwargs)

    # llm
    llm = setup_llm(train_config, model_config, **kwargs)

    # projector
    encoder_projector = setup_encoder_projector(
        train_config, model_config, **kwargs
    )
    model = slam_model_seld(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    )

    ckpt_path = kwargs.get(
        "ckpt_path", None
    )  # FIX(MZY): load model ckpt (mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
    if ckpt_path is not None:
        logger.info("loading other parts from: {}".format(ckpt_path))
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt_dict, strict=False)

    print_model_size(
        model,
        train_config,
        (
            int(os.environ["RANK"])
            if train_config.enable_fsdp or train_config.enable_ddp
            else 0
        ),
    )
    return model, tokenizer


class slam_model_seld(slam_model):
    def __init__(
        self,
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    ):
        super().__init__(
            encoder,
            llm,
            encoder_projector,
            tokenizer,
            train_config,
            model_config,
            **kwargs,
        )

    @torch.no_grad()
    def inference(
        # review comment on this method: "to delete"
        self,
        wav_path=None,
        prompt=None,
        generation_config=None,
        logits_processor=None,
        stopping_criteria=None,
        prefix_allowed_tokens_fn=None,
        synced_gpus=None,
        assistant_model=None,
        streamer=None,
        negative_prompt_ids=None,
        negative_prompt_attention_mask=None,
        **kwargs,
    ):
        # single-sample inference for audio-text QA (with audio) or text-only QA
        device = kwargs.get("device", "cuda")
        if os.path.exists(wav_path):  # Audio-Text QA
            import whisper

            audio_raw = whisper.load_audio(wav_path)
            audio_raw = whisper.pad_or_trim(audio_raw)

            mel_size = getattr(
                self.dataset_config, "mel_size", 80
            )  # 80 for large v1 and v2, 128 for large v3
            audio_mel = (
                whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size)
                .permute(1, 0)[None, :, :]
                .to(device)
            )

            encoder_outs = self.encoder.extract_variable_length_features(
                audio_mel.permute(0, 2, 1)
            )

            if self.model_config.encoder_projector == "q-former":
                audio_mel_post_mask = torch.ones(
                    encoder_outs.size()[:-1], dtype=torch.long
                ).to(encoder_outs.device)
                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
        else:  # Text QA
            encoder_outs = torch.empty(
                1, 0, self.llm.model.embed_tokens.embedding_dim
            ).to(device)

        prompt = "USER: {}\n ASSISTANT:".format(prompt)
        prompt_ids = self.tokenizer.encode(prompt)
        prompt_length = len(prompt_ids)
        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)

        if hasattr(self.llm.model, "embed_tokens"):
            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
        elif hasattr(self.llm.model.model, "embed_tokens"):
            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
        else:
            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)

        inputs_embeds = torch.cat(
            (encoder_outs, inputs_embeds[None, :, :]), dim=1
        )  # [audio, prompt]

        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
            inputs_embeds.device
        )

        # generate
        model_outputs = self.generate(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs
        )

        return model_outputs
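
To make the shape bookkeeping in `inference` concrete, here is a small self-contained sketch (not from the PR; the sizes are made up, apart from 4096, which matches `llm_dim` for llama-2-7b, and 64, which matches `fix_length_audio=64` in the decode script) of how the projected audio features are prepended to the embedded prompt before calling `generate`.

```python
# Toy illustration of the [audio, prompt] concatenation used in inference().
import torch

llm_dim = 4096
audio_embeds = torch.randn(1, 64, llm_dim)   # e.g. 64 projected audio tokens
prompt_embeds = torch.randn(12, llm_dim)     # embedded "USER: ...\n ASSISTANT:" tokens

inputs_embeds = torch.cat((audio_embeds, prompt_embeds[None, :, :]), dim=1)  # [audio, prompt]
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long)

print(inputs_embeds.shape)   # torch.Size([1, 76, 4096])
print(attention_mask.shape)  # torch.Size([1, 76])
```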
@@ -0,0 +1,59 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1

SLAM_DIR=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/SLAM-LLM
cd $SLAM_DIR
code_dir=examples/seld_spatialsoundqa

stage=classification
qa_data_root=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/data/SpatialAudio/closed-end
reverb_data_root=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/data/SpatialAudio/reverb/mp3d
anechoic_data_root=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/data/AudioSet

audio_encoder_path=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/models/SpatialAST/SpatialAST.pth
llm_path=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/models/llama-2-hf

split=eval
# output_dir=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/SLAM-LLM/outputs/bat-vicuna-7b-v1.5-spatialAST-qformer-steplrwarmupkeep1e-4-${stage}-$(date +"%Y%m%d")
output_dir=/mnt/cloudstorfs/sjtu_home/zhisheng.zheng/SLAM-LLM/outputs/bat-llama-2-spatialAST-qformer-steplrwarmupkeep1e-4-classification-20240507
ckpt_path=$output_dir/bat_epoch_2_step_2576
decode_log=$ckpt_path/decode_${split}_beam4

# -m debugpy --listen 5678 --wait-for-client
python -u $code_dir/inference_seld_batch.py \
    --config-path "conf" \
    hydra.run.dir=$ckpt_path \
    ++model_config.llm_name=llama-2-7b \
    ++model_config.llm_path=$llm_path \
    ++model_config.llm_dim=4096 \
    ++model_config.encoder_name=SpatialAST \
    ++model_config.encoder_projector=q-former \
    ++model_config.encoder_ckpt=$audio_encoder_path \
    ++dataset_config.stage=$stage \
    ++dataset_config.qa_data_root=$qa_data_root \
    ++dataset_config.anechoic_data_root=$anechoic_data_root \
    ++dataset_config.reverb_data_root=$reverb_data_root \
    ++dataset_config.fix_length_audio=64 \
    ++dataset_config.inference_mode=true \
    ++train_config.model_name=bat \
    ++train_config.freeze_encoder=true \
    ++train_config.freeze_llm=true \
    ++train_config.batching_strategy=custom \
    ++train_config.num_epochs=1 \
    ++train_config.val_batch_size=8 \
    ++train_config.num_workers_dataloader=2 \
    ++train_config.output_dir=$output_dir \
    ++train_config.use_peft=true \
    ++peft_config.peft_method=llama_adapter \
    ++log_config.log_file=$output_dir/test.log \
    ++decode_log=$decode_log \
    ++ckpt_path=$ckpt_path/model.pt \
    # ++peft_ckpt=$ckpt_path \
    # ++train_config.use_peft=true \
    # ++train_config.peft_config.r=32 \
    # ++dataset_config.normalize=true \
    # ++model_config.encoder_projector=q-former \
    # ++dataset_config.fix_length_audio=64 \
> Review comment (on this script): TODO: Add checkpoint.