Merge pull request #1 from neuralaudio/baseline-model
Baseline model
jorshi authored Jul 10, 2021
2 parents 82220e1 + 8c526f6 commit 5c0674e
Showing 10 changed files with 539 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# PyCharm
.idea/
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,8 @@
repos:
# Black is a code formatter that ensures consistent code styling
# The pre-commit hook runs it before a commit to ensure that code conforms
- repo: https://github.com/psf/black
rev: 21.5b0
hooks:
- id: black
language_version: python3
66 changes: 64 additions & 2 deletions README.md
@@ -1,2 +1,64 @@
# hear-baseline
Simple baseline model for the HEAR 2021 NeurIPS competition
![HEAR2021](https://neuralaudio.ai/assets/img/hear-header-sponsor.jpg)
# HEAR 2021 Baseline

A simple DSP-based audio embedding consisting of a Mel-frequency spectrogram followed
by a random projection. It serves as the naive baseline model for the HEAR 2021
competition and implements the
[common API](https://neuralaudio.ai/hear2021-holistic-evaluation-of-audio-representations.html#common-api)
required by the competition evaluation.

For full details on the HEAR 2021 NeurIPS competition and for information on how to
participate, please visit the
[competition website](https://neuralaudio.ai/hear2021-holistic-evaluation-of-audio-representations.html).

### Installation

**Method 1: PyPI**
```shell
pip install hearbaseline
```

**Method 2: pip local source tree**

This is the same method that will be used by the competition organizers when installing
submissions to HEAR 2021.
```shell
git clone https://github.com/neuralaudio/hear-baseline.git
python3 -m pip install ./hear-baseline
```

### Naive Baseline Model
The naive baseline model produces log-scaled Mel-frequency spectrograms using a
256-band mel filterbank. Each frame of the spectrogram is then projected to 4096
dimensions using a random projection matrix. Weights for the projection matrix were
generated by sampling a uniform distribution (see `RandomProjectionMelEmbedding` in
`hearbaseline/naive.py`) and are stored in this repository in the file
`saved_models/naive_baseline.pt`.

A random projection is less efficient than a learned model such as a CNN, but it is
one of the simplest embedding models to implement.
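
For illustration, the per-frame computation can be sketched in a few lines of
PyTorch. This is a minimal sketch, not the model itself: it uses a random
placeholder in place of the real mel filterbank, omits the Hann window for brevity,
and assumes the defaults above (`n_fft=4096`, `n_mels=256`, embedding size 4096):

```python
import math

import torch

n_fft, n_mels, emb_size = 4096, 256, 4096

frame = torch.rand(n_fft)                         # one 4096-sample audio frame
spectrum = torch.abs(torch.fft.rfft(frame)) ** 2  # power spectrum, shape (2049,)

# Placeholder filterbank; the model uses librosa.filters.mel() weights instead
mel_fb = torch.rand(n_mels, spectrum.shape[0])
log_mel = torch.log(mel_fb @ spectrum + 1e-4)     # log mel spectrum, shape (256,)

# Random projection scaled by sqrt(n_mels), as in the baseline model
projection = torch.rand(n_mels, emb_size) / math.sqrt(n_mels)
embedding = log_mel @ projection                  # embedding, shape (4096,)
```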

### Usage

Audio embeddings can be computed using one of two methods: 1)
`get_scene_embeddings`, or 2) `get_timestamp_embeddings`.

`get_scene_embeddings` accepts a batch of audio clips and produces a single embedding
for each audio clip. This can be computed like so:
```python
import torch
import hearbaseline

# Load model with weights - located in the root directory of this repo
model = hearbaseline.load_model("saved_models/naive_baseline.pt")

# Create a batch of 2 random-noise clips that are 2 seconds long
# and compute scene embeddings for each clip
audio = torch.rand((2, model.sample_rate * 2))
embeddings = hearbaseline.get_scene_embeddings(audio, model)
```

The `get_timestamp_embeddings` method works in the same way but returns an array
of embeddings computed every 25 ms over the duration of the input audio. An array
of timestamps corresponding to each embedding is also returned.
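
Continuing the example above (a hypothetical snippet; `audio` and `model` are the
objects created in the scene-embedding example):

```python
# Compute an embedding every 25 ms, along with the corresponding timestamps
embeddings, timestamps = hearbaseline.get_timestamp_embeddings(audio, model)

# embeddings: (2, n_timestamps, model.timestamp_embedding_size)
# timestamps: (n_timestamps,), in milliseconds
print(embeddings.shape, timestamps.shape)
```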

See the [common API](https://neuralaudio.ai/hear2021-holistic-evaluation-of-audio-representations.html#common-api)
for more details.
1 change: 1 addition & 0 deletions hearbaseline/__init__.py
@@ -0,0 +1 @@
from .naive import load_model, get_scene_embeddings, get_timestamp_embeddings
189 changes: 189 additions & 0 deletions hearbaseline/naive.py
@@ -0,0 +1,189 @@
"""
Baseline model for the HEAR 2021 NeurIPS competition.
This is simply a mel spectrogram followed by a random projection.
"""

from collections import OrderedDict
import math
from typing import Tuple

import librosa
import torch
from torch import Tensor

from hearbaseline.util import frame_audio

# Default hop_size in milliseconds
HOP_SIZE = 25

# Number of frames to batch process for timestamp embeddings
BATCH_SIZE = 512


class RandomProjectionMelEmbedding(torch.nn.Module):
"""
    Baseline audio embedding model. This model creates mel-frequency spectrograms with
    256 mel bands, and then performs a projection to an embedding size of 4096.
"""

# sample rate and embedding sizes are required model attributes for the HEAR API
sample_rate = 44100
embedding_size = 4096
scene_embedding_size = embedding_size
timestamp_embedding_size = embedding_size

# These attributes are specific to this baseline model
n_fft = 4096
n_mels = 256
seed = 0
epsilon = 1e-4

def __init__(self):
super().__init__()
torch.random.manual_seed(self.seed)

# Create a Hann window buffer to apply to frames prior to FFT.
self.register_buffer("window", torch.hann_window(self.n_fft))

# Create a mel filter buffer.
mel_scale: Tensor = torch.tensor(
            librosa.filters.mel(sr=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels)
)
self.register_buffer("mel_scale", mel_scale)

# Projection matrices.
normalization = math.sqrt(self.n_mels)
self.projection = torch.nn.Parameter(
torch.rand(self.n_mels, self.embedding_size) / normalization
)

def forward(self, x: Tensor):
# Compute the real-valued Fourier transform on windowed input signal.
x = torch.fft.rfft(x * self.window)

# Convert to a power spectrum.
x = torch.abs(x) ** 2.0

# Apply the mel-scale filter to the power spectrum.
x = torch.matmul(x, self.mel_scale.transpose(0, 1))

# Convert to a log mel spectrum.
x = torch.log(x + self.epsilon)

        # Apply the projection to produce a 4096-dimensional embedding
embedding = x.matmul(self.projection)

return embedding


def load_model(model_file_path: str = "") -> torch.nn.Module:
"""
Returns a torch.nn.Module that produces embeddings for a frame of audio.
Args:
model_file_path: Load model checkpoint from this file path. For this baseline,
if no path is provided then the default random init weights for the
linear projection layer will be used.
Returns:
Model: torch.nn.Module loaded on the specified device.
"""
model = RandomProjectionMelEmbedding()
if model_file_path != "":
loaded_model = torch.load(model_file_path)
if not isinstance(loaded_model, OrderedDict):
raise TypeError(
f"Loaded model must be a model state dict of type OrderedDict. "
f"Received {type(loaded_model)}"
)

model.load_state_dict(loaded_model)

return model


def get_timestamp_embeddings(
audio: Tensor,
model: torch.nn.Module,
) -> Tuple[Tensor, Tensor]:
"""
This function returns embeddings at regular intervals centered at timestamps. Both
the embeddings and corresponding timestamps (in milliseconds) are returned.
Args:
audio: n_sounds x n_samples of mono audio in the range [-1, 1].
model: Loaded model.
Returns:
        - Tensor: embeddings, a float32 Tensor with shape (n_sounds, n_timestamps,
            model.timestamp_embedding_size).
- Tensor: timestamps, Centered timestamps in milliseconds corresponding
to each embedding in the output.
"""

# Assert audio is of correct shape
if audio.ndim != 2:
raise ValueError(
"audio input tensor must be 2D with shape (batch_size, num_samples)"
)

# Make sure the correct model type was passed in
if not isinstance(model, RandomProjectionMelEmbedding):
raise ValueError(
f"Model must be an instance of {RandomProjectionMelEmbedding.__name__}"
)

# Send the model to the same device that the audio tensor is on.
model = model.to(audio.device)

# Split the input audio signals into frames and then flatten to create a tensor
# of audio frames that can be batch processed.
frames, timestamps = frame_audio(
audio,
frame_size=model.n_fft,
hop_size=HOP_SIZE,
sample_rate=RandomProjectionMelEmbedding.sample_rate,
)
audio_batches, num_frames, frame_size = frames.shape
frames = frames.flatten(end_dim=1)

# We're using a DataLoader to help with batching of frames
dataset = torch.utils.data.TensorDataset(frames)
loader = torch.utils.data.DataLoader(
dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False
)

    # Put the model into eval mode and disable gradient computation during inference.
    # Iterate over all batches and accumulate the embeddings for each frame.
model.eval()
with torch.no_grad():
embeddings_list = [model(batch[0]) for batch in loader]

# Concatenate mini-batches back together and unflatten the frames
# to reconstruct the audio batches
embeddings = torch.cat(embeddings_list, dim=0)
embeddings = embeddings.unflatten(0, (audio_batches, num_frames))

return embeddings, timestamps


def get_scene_embeddings(
audio: Tensor,
model: torch.nn.Module,
) -> Tensor:
"""
This function returns a single embedding for each audio clip. In this baseline
implementation we simply summarize the temporal embeddings from
get_timestamp_embeddings() using torch.mean().
Args:
audio: n_sounds x n_samples of mono audio in the range [-1, 1]. All sounds in
a batch will be padded/trimmed to the same length.
model: Loaded model.
Returns:
- embeddings, A float32 Tensor with shape (n_sounds, model.scene_embedding_size).
"""
embeddings, _ = get_timestamp_embeddings(audio, model)
embeddings = torch.mean(embeddings, dim=1)
return embeddings
56 changes: 56 additions & 0 deletions hearbaseline/util.py
@@ -0,0 +1,56 @@
"""
Utility functions for hear-kit
"""

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def frame_audio(
audio: Tensor, frame_size: int, hop_size: float, sample_rate: int
) -> Tuple[Tensor, Tensor]:
"""
    Slices input audio into frames that are centered and occur every
    sample_rate * hop_size / 1000 samples (hop_size is in milliseconds). We round to
    the nearest sample.
Args:
audio: input audio, expects a 2d Tensor of shape:
(batch_size, num_samples)
frame_size: the number of samples each resulting frame should be
hop_size: hop size between frames, in milliseconds
sample_rate: sampling rate of the input audio
Returns:
- A Tensor of shape (batch_size, num_frames, frame_size)
- A 1d Tensor of timestamps corresponding to the frame
centers.
"""

    # Zero-pad the beginning and the end of the incoming audio with half a frame's
    # worth of samples. This centers the audio in the middle of each frame with
    # respect to the timestamps.
audio = F.pad(audio, (frame_size // 2, frame_size - frame_size // 2))
num_padded_samples = audio.shape[1]

frame_number = 0
frames = []
timestamps = []
frame_start = 0
frame_end = frame_size
while True:
frames.append(audio[:, frame_start:frame_end])
timestamps.append(frame_number * hop_size)

# Increment the frame_number and break the loop if the next frame end
# will extend past the end of the padded audio samples
frame_number += 1
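        # Convert the next hop position from milliseconds to samples, rounding to
        # the nearest sample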
frame_start = int(round(sample_rate * frame_number * hop_size / 1000))
frame_end = frame_start + frame_size

if not frame_end <= num_padded_samples:
break

return torch.stack(frames, dim=1), torch.tensor(timestamps)
Binary file added saved_models/naive_baseline.pt
Binary file not shown.
37 changes: 37 additions & 0 deletions setup.py
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
from setuptools import find_packages, setup

long_description = open("README.md", "r", encoding="utf-8").read()

setup(
name="hearbaseline",
version="2021.0.1",
description="Holistic Evaluation of Audio Representations (HEAR) 2021 -- Baseline Model",
author="HEAR 2021 NeurIPS Competition Committee",
author_email="deep-at-neuralaudio.ai",
url="https://github.com/neuralaudio/hear-baseline",
license="Apache-2.0",
long_description=long_description,
long_description_content_type="text/markdown",
project_urls={
"Bug Tracker": "https://github.com/neuralaudio/hear-baseline/issues",
"Source Code": "https://github.com/neuralaudio/hear-baseline",
},
packages=find_packages(exclude=("tests",)),
python_requires=">=3.6",
install_requires=["librosa", "torch"],
extras_require={
"test": [
"pytest",
"pytest-cov",
"pytest-env",
],
"dev": [
"pre-commit",
"black", # Used in pre-commit hooks
"pytest",
"pytest-cov",
"pytest-env",
],
},
)
Empty file added tests/__init__.py
Empty file.