Merge pull request #74 from ddlBoJack/bat
Bat
Showing 16 changed files with 1,206 additions and 6 deletions.
@@ -0,0 +1,46 @@
# <img src="assets/bat.png" alt="SELD_SpatialSoundQA" width="25" height="25"> SELD_SpatialSoundQA

This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].

Check out our [demo page](https://zhishengzheng.com/BAT/) and enjoy a QA game with spatial audio.
## Performance and checkpoints

Encoder | Projector | PEFT | LLM
---|---|---|---
[Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter | [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b)
## Data preparation

Prepare the data as a JSONL file in the format below, one JSON object per line. You can download the SpatialSoundQA dataset from [Hugging Face](https://huggingface.co/datasets/zhisheng01/SpatialSoundQA).
```
{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
...
{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"}
```
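A minimal sketch (not from the repo; the file name is a hypothetical placeholder) of loading and inspecting this manifest:

```python
import json

# Read one question per line from a hypothetical manifest file.
with open("eval.jsonl") as f:
    samples = [json.loads(line) for line in f]

# "audio_id2"/"reverb_id2" are null except for two-source (MIXUP_*) questions.
print(samples[0]["question_type"], "->", samples[0]["question"])
```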
## Train a new model
```bash
bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
```

## Decoding with checkpoints
```bash
bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
```
## TODO
- [x] Decode with checkpoints
- [x] Upload SpatialSoundQA dataset
- [ ] Upload pretrained checkpoints
- [ ] Update model performance
## Citation
```
@article{zheng2024bat,
  author  = {Zheng, Zhisheng and Peng, Puyuan and Ma, Ziyang and Chen, Xie and Choi, Eunsol and Harwath, David},
  title   = {BAT: Learning to Reason about Spatial Sounds with Large Language Models},
  journal = {arXiv preprint arXiv:2402.01591},
  year    = {2024},
}
```
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
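This is a DeepSpeed configuration: ZeRO stage-3 partitioning with the optimizer state offloaded to CPU, fp16 training, Adam at lr 1e-4, and a micro-batch of 4 per GPU. As a hedged sketch (not this repo's launch path; the model and the config path below are placeholders), such a file is consumed via `deepspeed.initialize`:

```python
import deepspeed
import torch.nn as nn

# Placeholder model purely for illustration.
model = nn.Linear(16, 16)

# `config` accepts a dict or a path to a JSON file like the one above
# (the path here is hypothetical).
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",
)
```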
examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py (152 additions, 0 deletions)
@@ -0,0 +1,152 @@
import os
import json

import numpy as np
import soundfile as sf
from scipy import signal

import torch

from slam_llm.datasets.base_dataset import BaseDataset
def format_prompt(instruction, input=None):
    PROMPT_DICT = {
        "prompt_input": (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
        ),
        "prompt_no_input": (
            "Based on the audio you've heard, refer to the instruction and provide a response.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:"
        ),
    }
    if input is None:
        return PROMPT_DICT["prompt_no_input"].format_map({"instruction": instruction})
    else:
        return PROMPT_DICT["prompt_input"].format_map({"instruction": instruction, "input": input})

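# --- Illustrative usage (not part of the original file) ----------------------
# With input=None the "prompt_no_input" template is used:
#
#   >>> print(format_prompt("Enumerate the sound occurrences in the audio clip."))
#   Based on the audio you've heard, refer to the instruction and provide a response.
#
#   ### Instruction:
#   Enumerate the sound occurrences in the audio clip.
#
#   ### Response:
# ------------------------------------------------------------------------------
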
class SpatialAudioDatasetJsonl(BaseDataset):
    def __init__(
        self,
        dataset_config,
        tokenizer,
        split,
    ):
        super().__init__()
        dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.jsonl')
        with open(dataset_path) as f:
            self.data = [json.loads(line) for line in f]

        self.anechoic_data_root = dataset_config['anechoic_data_root']  # AudioSet in this case
        self.reverb_data_root = dataset_config['reverb_data_root']
        self.channel_type = dataset_config['channel_type']

        self.ext_audio = dataset_config['ext_audio']
        self.max_words = dataset_config['max_words']
        self.fix_length_audio = dataset_config.get("fix_length_audio", -1)

        self.tokenizer = tokenizer

        self.normalize = dataset_config['normalize']
        self.inference_mode = dataset_config['inference_mode']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]

        audio_path = os.path.join(self.anechoic_data_root, sample['audio_id'] + self.ext_audio)
        reverb_path = os.path.join(self.reverb_data_root, self.channel_type, sample['reverb_id'])

        if sample['audio_id2'] is not None and sample['reverb_id2'] is not None:
            audio_path2 = os.path.join(self.anechoic_data_root, sample['audio_id2'] + self.ext_audio)
            reverb_path2 = os.path.join(self.reverb_data_root, self.channel_type, sample['reverb_id2'])
        else:
            audio_path2 = None
            reverb_path2 = None

        waveforms = self.load_waveform(audio_path, reverb_path, audio_path2, reverb_path2)

        prompt = format_prompt(sample['question'], None)
        answer = sample['answer']

        if not self.inference_mode:
            return super().__getitem__((waveforms, None, prompt, answer))
        else:
            base_sample = super().__getitem__((waveforms, None, prompt, answer))
            base_sample.update({
                "key": f"{sample['question_type']}-{sample['question_id']}",
                "target": sample['answer']
            })
            return base_sample

    @classmethod
    def normalize_audio(cls, audio_data, target_dBFS=-14.0):
        rms = np.sqrt(np.mean(audio_data ** 2))  # RMS of the audio

        if rms == 0:  # avoid division by zero for completely silent audio
            return audio_data

        current_dBFS = 20 * np.log10(rms)     # convert RMS to dBFS
        gain_dB = target_dBFS - current_dBFS  # required gain in dB
        gain_linear = 10 ** (gain_dB / 20)    # convert the gain to a linear scale
        return audio_data * gain_linear       # apply the gain

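    # --- Illustrative check (not part of the original file) ------------------
    # normalize_audio scales the signal so that 20 * log10(rms) == target_dBFS:
    #
    #   >>> x = np.full(32000, 0.5)   # constant signal, RMS = 0.5 (~ -6 dBFS)
    #   >>> y = SpatialAudioDatasetJsonl.normalize_audio(x, target_dBFS=-14.0)
    #   >>> round(20 * np.log10(np.sqrt(np.mean(y ** 2))), 1)
    #   -14.0
    # --------------------------------------------------------------------------
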
    @classmethod
    def load_waveform(cls, audio_path, reverb_path=None, audio_path2=None, reverb_path2=None, normalize=True):
        waveform, sr = sf.read(audio_path)

        if len(waveform.shape) > 1:
            waveform = waveform[:, 0]  # keep the first channel only
        if sr != 32000:
            waveform = signal.resample_poly(waveform, 32000, sr)
            sr = 32000
        if normalize:
            waveform = cls.normalize_audio(waveform, -14.0)

        waveform = waveform.reshape(1, -1)
        if reverb_path is not None:
            # Spatialize the anechoic clip by convolving it with the room
            # impulse response.
            reverb = np.load(reverb_path)
            waveform = signal.fftconvolve(waveform, reverb, mode='full')

        waveform = torch.from_numpy(waveform).float()
        waveform = cls.padding(waveform, padding_length=10 * sr - waveform.shape[1])

        if audio_path2 is not None and reverb_path2 is not None:
            # Second source for MIXUP_* questions: same pipeline, then average.
            waveform2, sr2 = sf.read(audio_path2)

            if len(waveform2.shape) > 1:
                waveform2 = waveform2[:, 0]
            if sr2 != 32000:
                waveform2 = signal.resample_poly(waveform2, 32000, sr2)
                sr2 = 32000
            if normalize:
                waveform2 = cls.normalize_audio(waveform2, -14.0)

            waveform2 = waveform2.reshape(1, -1)
            reverb2 = np.load(reverb_path2)
            waveform2 = signal.fftconvolve(waveform2, reverb2, mode='full')
            waveform2 = torch.from_numpy(waveform2).float()
            waveform2 = cls.padding(waveform2, padding_length=10 * sr - waveform2.shape[1])

            waveform = (waveform + waveform2) / 2
        return waveform

    def collator(self, samples):
        audio = torch.stack([s['audio'] for s in samples])

        collated = super().collator(samples)
        collated['audio'] = audio

        return collated


def get_spatial_audio_dataset(dataset_config, tokenizer, split):
    dataset = SpatialAudioDatasetJsonl(dataset_config, tokenizer, split)
    return dataset
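A minimal usage sketch (not from the repo; every config value below is a hypothetical placeholder, with keys mirroring those read in `__init__`, and the import path assumes the file location shown above is importable):

```python
from examples.seld_spatialsoundqa.dataset.spatial_audio_dataset import SpatialAudioDatasetJsonl

# Hypothetical config; adjust paths and values to your local data layout.
dataset_config = {
    "qa_data_root": "SpatialSoundQA",   # root containing <stage>/<split>.jsonl
    "stage": "stage1",                  # hypothetical stage name
    "anechoic_data_root": "AudioSet",   # anechoic source clips
    "reverb_data_root": "reverb",       # root of the .npy impulse responses
    "channel_type": "binaural",         # hypothetical channel type
    "ext_audio": ".wav",
    "max_words": 96,
    "normalize": True,
    "inference_mode": False,
}

dataset = SpatialAudioDatasetJsonl(dataset_config, tokenizer=None, split="train")
print(len(dataset))
```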
@@ -0,0 +1,48 @@
import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf

from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig
from slam_llm.pipeline.finetune import main as train

@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
    )

@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()
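As a hedged sketch (not a documented entry point; it assumes code appended to this script, and the overridden fields come from the RunConfig dataclass above), the merge `main_hydra` performs can be reproduced programmatically:

```python
from omegaconf import OmegaConf

# Start from the RunConfig defaults, then overlay user overrides, exactly
# as main_hydra does with the hydra-provided cfg.
base = OmegaConf.structured(RunConfig())
overrides = OmegaConf.create({"metric": "acc", "ckpt_path": "output/model.pt"})
cfg = OmegaConf.merge(base, overrides)
train(cfg)
```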
@@ -0,0 +1,53 @@
import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, OmegaConf
from typing import Optional

from slam_llm.pipeline.inference_batch import main as inference
from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig

@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    decode_log: str = field(
        default="output/decode_log",
        metadata={"help": "The prefix for the decode output"},
    )
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to peft checkpoint, should be a directory including adapter_config.json"
        },
    )

@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)

    if cfg.get("debug", False):
        import pdb
        pdb.set_trace()

    inference(cfg)


if __name__ == "__main__":
    main_hydra()
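The decode entry point follows the same pattern; a hedged sketch (again assuming code appended to this script, using only fields defined in the RunConfig above):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.merge(
    OmegaConf.structured(RunConfig()),
    OmegaConf.create({
        "ckpt_path": "output/model.pt",     # projector checkpoint
        "decode_log": "output/decode_log",  # prefix for decode output files
    }),
)
inference(cfg)
```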