Commit
Merge pull request #74 from ddlBoJack/bat
Bat
ddlBoJack authored May 20, 2024
2 parents 3877c50 + 89b70be commit c89b620
Showing 16 changed files with 1,206 additions and 6 deletions.
46 changes: 46 additions & 0 deletions examples/seld_spatialsoundqa/README.md
@@ -0,0 +1,46 @@
# <img src="assets/bat.png" alt="SELD_SpatialSoundQA" width="25" height="25"> SELD_SpatialSoundQA

This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].

Check out our [demo page](https://zhishengzheng.com/BAT/) and enjoy a QA game with spatial audio.

## Performance and checkpoints
| Encoder | Projector | PEFT | LLM |
|---|---|---|---|
| [Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter | [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) |

## Data preparation
You need to prepare the data as a JSONL file in the format below; an example is shown inside the block.
You can download the SpatialSoundQA dataset from [huggingface](https://huggingface.co/datasets/zhisheng01/SpatialSoundQA).
```
{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
...
{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"}
```
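
As a quick sanity check (not part of the released scripts), you can load a downloaded split with a few lines of Python; the file name and location below are assumptions about your local layout, not paths defined by this repo.

```python
# Hedged sketch: inspect a SpatialSoundQA JSONL split.
import json

qa_path = "SpatialSoundQA/eval.jsonl"  # hypothetical local path
with open(qa_path) as f:
    samples = [json.loads(line) for line in f]

print(len(samples))
print(samples[0]["question_type"], "|", samples[0]["question"], "->", samples[0]["answer"])
```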

## Train a new model
```bash
bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
```

## Decoding with checkpoints
```bash
bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
```


## TODO
- [x] Decode with checkpoints
- [x] Upload SpatialSoundQA dataset
- [ ] Upload pretrained checkpoints
- [ ] Update model performance

## Citation
```
@article{zheng2024bat,
  author  = {Zheng, Zhisheng and Peng, Puyuan and Ma, Ziyang and Chen, Xie and Choi, Eunsol and Harwath, David},
  title   = {BAT: Learning to Reason about Spatial Sounds with Large Language Models},
  journal = {arXiv preprint arXiv:2402.01591},
  year    = {2024},
}
```
Binary file added examples/seld_spatialsoundqa/assets/bat.png
19 changes: 19 additions & 0 deletions examples/seld_spatialsoundqa/conf/ds_config.json
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
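
As a rough illustration (not taken from this PR), a JSON config like the one above is typically handed to `deepspeed.initialize`; the model below is a stand-in, not the actual BAT wiring, and the optimizer is built by DeepSpeed from the JSON.

```python
# Hedged sketch: consuming a DeepSpeed JSON config. The Linear layer is a placeholder model.
import deepspeed
import torch.nn as nn

model = nn.Linear(512, 512)  # placeholder module
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="examples/seld_spatialsoundqa/conf/ds_config.json",
)
```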
152 changes: 152 additions & 0 deletions examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py
@@ -0,0 +1,152 @@
import os
import random
import json
import copy

import numpy as np
import soundfile as sf
from scipy import signal

import torch

from slam_llm.datasets.base_dataset import BaseDataset

def format_prompt(instruction, input=None):
    PROMPT_DICT = {
        "prompt_input": (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
        ),
        "prompt_no_input": (
            "Based on the audio you've heard, refer to the instruction and provide a response.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:"
        ),
    }
    if input is None:
        return PROMPT_DICT['prompt_no_input'].format_map({'instruction': instruction})
    else:
        return PROMPT_DICT["prompt_input"].format_map({'instruction': instruction, 'input': input})


class SpatialAudioDatasetJsonl(BaseDataset):
    def __init__(
        self,
        dataset_config,
        tokenizer,
        split,
    ):
        super().__init__()
        dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.jsonl')
        with open(dataset_path) as f:
            self.data = [json.loads(line) for line in f.readlines()]

        self.anechoic_data_root = dataset_config['anechoic_data_root']  # which is AudioSet in this case
        self.reverb_data_root = dataset_config['reverb_data_root']
        self.channel_type = dataset_config['channel_type']

        self.ext_audio = dataset_config['ext_audio']
        self.max_words = dataset_config['max_words']
        self.fix_length_audio = dataset_config.get("fix_length_audio", -1)

        self.tokenizer = tokenizer

        self.normalize = dataset_config['normalize']
        self.inference_mode = dataset_config['inference_mode']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]

        audio_path = os.path.join(self.anechoic_data_root, sample['audio_id'] + self.ext_audio)
        reverb_path = os.path.join(self.reverb_data_root, self.channel_type, sample['reverb_id'])

        if sample['audio_id2'] is not None and sample['reverb_id2'] is not None:
            audio_path2 = os.path.join(self.anechoic_data_root, sample['audio_id2'] + self.ext_audio)
            reverb_path2 = os.path.join(self.reverb_data_root, self.channel_type, sample['reverb_id2'])
        else:
            audio_path2 = None
            reverb_path2 = None

        waveforms = self.load_waveform(audio_path, reverb_path, audio_path2, reverb_path2)

        prompt = sample['question']
        prompt = format_prompt(prompt, None)
        answer = sample['answer']

        if not self.inference_mode:
            return super().__getitem__((waveforms, None, prompt, answer))
        else:
            base_sample = super().__getitem__((waveforms, None, prompt, answer))
            base_sample.update({
                "key": f"{sample['question_type']}-{sample['question_id']}",
                "target": sample['answer']
            })
            return base_sample

    @classmethod
    def normalize_audio(cls, audio_data, target_dBFS=-14.0):
        rms = np.sqrt(np.mean(audio_data**2))  # Calculate the RMS of the audio

        if rms == 0:  # Avoid division by zero in case of a completely silent audio
            return audio_data

        current_dBFS = 20 * np.log10(rms)  # Convert RMS to dBFS
        gain_dB = target_dBFS - current_dBFS  # Calculate the required gain in dB
        gain_linear = 10 ** (gain_dB / 20)  # Convert gain from dB to linear scale
        normalized_audio = audio_data * gain_linear  # Apply the gain to the audio data
        return normalized_audio

    @classmethod
    def load_waveform(cls, audio_path, reverb_path=None, audio_path2=None, reverb_path2=None, normalize=True):
        waveform, sr = sf.read(audio_path)

        if len(waveform.shape) > 1:
            waveform = waveform[:, 0]
        if sr != 32000:
            waveform = signal.resample_poly(waveform, 32000, sr)
            sr = 32000
        if normalize:
            waveform = cls.normalize_audio(waveform, -14.0)

        waveform = waveform.reshape(1, -1)
        if reverb_path is not None:
            reverb = np.load(reverb_path)
            waveform = signal.fftconvolve(waveform, reverb, mode='full')

        waveform = torch.from_numpy(waveform).float()
        waveform = cls.padding(waveform, padding_length=10*sr-waveform.shape[1])

        if audio_path2 is not None and reverb_path2 is not None:
            waveform2, sr2 = sf.read(audio_path2)

            if len(waveform2.shape) > 1:
                waveform2 = waveform2[:, 0]
            if sr2 != 32000:
                waveform2 = signal.resample_poly(waveform2, 32000, sr2)
                sr2 = 32000
            if normalize:
                waveform2 = cls.normalize_audio(waveform2, -14.0)

            waveform2 = waveform2.reshape(1, -1)
            reverb2 = np.load(reverb_path2)
            waveform2 = signal.fftconvolve(waveform2, reverb2, mode='full')
            waveform2 = torch.from_numpy(waveform2).float()
            waveform2 = cls.padding(waveform2, padding_length=10*sr-waveform2.shape[1])

            waveform = (waveform + waveform2) / 2
        return waveform

    def collator(self, samples):
        audio = torch.stack([s['audio'] for s in samples])

        collated = super().collator(samples)
        collated['audio'] = audio

        return collated

def get_spatial_audio_dataset(dataset_config, tokenizer, split):
    dataset = SpatialAudioDatasetJsonl(dataset_config, tokenizer, split)
    return dataset
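
For orientation only, here is a hedged sketch of how this dataset class might be constructed outside the training pipeline. Every path, the stage name, the channel type, the audio extension, and the tokenizer choice are placeholders, not the values the scripts in this PR actually use.

```python
# Hypothetical usage sketch for SpatialAudioDatasetJsonl; all values are placeholders.
from transformers import AutoTokenizer

dataset_config = {
    "qa_data_root": "/data/SpatialSoundQA",   # placeholder: root holding <stage>/<split>.jsonl
    "stage": "stage1",                        # placeholder stage name
    "anechoic_data_root": "/data/AudioSet",   # placeholder: anechoic source audio
    "reverb_data_root": "/data/reverb",       # placeholder: room impulse responses (*.npy)
    "channel_type": "binaural",               # placeholder channel layout
    "ext_audio": ".flac",                     # placeholder audio extension
    "max_words": 96,
    "normalize": True,
    "inference_mode": False,
}
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder tokenizer
train_set = get_spatial_audio_dataset(dataset_config, tokenizer, split="train")
print(len(train_set))
```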
48 changes: 48 additions & 0 deletions examples/seld_spatialsoundqa/finetune_seld.py
@@ -0,0 +1,48 @@
import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf

from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig
from slam_llm.pipeline.finetune import main as train

@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
    )

@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()
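
The core pattern in this entry point is merging a structured default `RunConfig` with whatever Hydra passes in, so CLI/YAML overrides win over the dataclass defaults. Below is a minimal, self-contained sketch of that merge with toy field names; these are not the real `TrainConfig` fields from `seld_config`.

```python
# Hedged sketch of the OmegaConf merge pattern used above; toy fields only.
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class ToyTrainConfig:
    lr: float = 1e-4
    num_epochs: int = 1

@dataclass
class ToyRunConfig:
    train_config: ToyTrainConfig = field(default_factory=ToyTrainConfig)
    metric: str = "acc"

defaults = ToyRunConfig()
overrides = OmegaConf.create({"train_config": {"num_epochs": 3}})  # stand-in for what Hydra hands in
cfg = OmegaConf.merge(defaults, overrides)
print(cfg.train_config.num_epochs)  # 3: overrides take precedence over dataclass defaults
```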
53 changes: 53 additions & 0 deletions examples/seld_spatialsoundqa/inference_seld_batch.py
@@ -0,0 +1,53 @@
import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Optional

from slam_llm.pipeline.inference_batch import main as inference
from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig

@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    decode_log: str = field(
        default="output/decode_log",
        metadata={"help": "The prefix for the decode output"},
    )
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to peft checkpoint, should be a directory including adapter_config.json"
        },
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    # kwargs = to_plain_list(cfg)
    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if cfg.get("debug", False):
        import pdb
        pdb.set_trace()

    inference(cfg)


if __name__ == "__main__":
    main_hydra()