Skip to content

Commit

Permalink
Merge pull request #10 from neuralaudio/wav2vec2
Browse files Browse the repository at this point in the history
Wav2vec2 HEAR
  • Loading branch information
turian authored Aug 26, 2021
2 parents b941adc + c1a2ea5 commit 78bc153
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 2 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/validation-wav2vec2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: hearvalidator on wav2vec2

on: [pull_request]

jobs:
build:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@master
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@master
with:
python-version: ${{ matrix.python-version }}
- name: apt-get
run: |
sudo apt-get install -y libsndfile-dev
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: python dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
pip install hearvalidator
- name: Validate
run: |
hear-validator hearbaseline.wav2vec2
2 changes: 1 addition & 1 deletion .github/workflows/validation.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: HEAR Validator on Baselines
name: hearvalidator on baselines

on: [pull_request]

Expand Down
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,6 @@ dmypy.json
.pyre/

# PyCharm
.idea/
.idea/

pretrained/
135 changes: 135 additions & 0 deletions hearbaseline/wav2vec2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""
wav2vec2 model for HEAR 2021 NeurIPS competition.
Adapted from
https://colab.research.google.com/drive/17Hu1pxqhfMisjkSgmM2CnZxfqDyn2hSY?usp=sharing
"""

from typing import Tuple

import torch

# from speechbrain.lobes.models.fairseq_wav2vec import FairseqWav2Vec2
from speechbrain.lobes.models.huggingface_wav2vec import HuggingFaceWav2Vec2
from torch import Tensor

# HuggingFace model hub
model_hub = "facebook/wav2vec2-base-960h"

# Faiseq model url
# model_url = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt"


def load_model(model_file_path: str = "") -> torch.nn.Module:
"""
Returns a torch.nn.Module that produces embeddings for a frame of audio.
Args:
model_file_path: Load model checkpoint from this file path. For this baseline,
if no path is provided then the default random init weights for the
linear projection layer will be used.
Returns:
Model: torch.nn.Module loaded on the specified device.
"""
# model_fairseq = FairseqWav2Vec2(model_url, save_path="pretrained/local_model.pt")
model_huggingface = HuggingFaceWav2Vec2(model_hub, save_path="pretrained/")
model = model_huggingface
if torch.cuda.is_available():
model.cuda()

# sample rate and embedding sizes are required model attributes for the HEAR API
model.sample_rate = 16000
model.embedding_size = 768
model.scene_embedding_size = model.embedding_size
model.timestamp_embedding_size = model.embedding_size

return model


def get_timestamp_embeddings(
audio: Tensor,
model: torch.nn.Module,
) -> Tuple[Tensor, Tensor]:
"""
This function returns embeddings at regular intervals centered at timestamps. Both
the embeddings and corresponding timestamps (in milliseconds) are returned.
Args:
audio: n_sounds x n_samples of mono audio in the range [-1, 1].
model: Loaded model.
Returns:
- Tensor: embeddings, A float32 Tensor with shape (n_sounds, n_timestamps,
model.timestamp_embedding_size).
- Tensor: timestamps, Centered timestamps in milliseconds corresponding
to each embedding in the output. Shape: (n_sounds, n_timestamps).
"""

# Assert audio is of correct shape
if audio.ndim != 2:
raise ValueError(
"audio input tensor must be 2D with shape (n_sounds, num_samples)"
)

# Make sure the correct model type was passed in
if not isinstance(model, HuggingFaceWav2Vec2):
raise ValueError(f"Model must be an instance of {HuggingFaceWav2Vec2.__name__}")

# Send the model to the same device that the audio tensor is on.
# model = model.to(audio.device)

# Put the model into eval mode, and not computing gradients while in inference.
# Iterate over all batches and accumulate the embeddings for each frame.
model.eval()
with torch.no_grad():
embeddings = model(audio)

# Length of the audio in MS
audio_ms = int(audio.shape[1] / model.sample_rate * 1000)

# samples => timestamps
# 31439 => 97
# 31440 => 98
# This is weird that its 5ms, not half the hopsize of 20
ntimestamps = (audio_ms - 5) // 20

# Also
# 32000 => 99
# 32080 => 100

# I don't know if this is their exact centering, but this matches
# their shape.
last_center = 12.5 + (ntimestamps - 1) * 20
timestamps = torch.arange(12.5, last_center + 20, 20)
assert len(timestamps) == ntimestamps
timestamps = timestamps.expand((embeddings.shape[0], timestamps.shape[0]))
assert timestamps.shape[1] == embeddings.shape[1]
timestamps = torch.zeros((embeddings.shape[0], embeddings.shape[1]))

return embeddings, timestamps


# TODO: There must be a better way to do scene embeddings,
# e.g. just truncating / padding the audio to 2 seconds
# and concatenating a subset of the embeddings.
def get_scene_embeddings(
audio: Tensor,
model: torch.nn.Module,
) -> Tensor:
"""
This function returns a single embedding for each audio clip. In this baseline
implementation we simply summarize the temporal embeddings from
get_timestamp_embeddings() using torch.mean().
Args:
audio: n_sounds x n_samples of mono audio in the range [-1, 1]. All sounds in
a batch will be padded/trimmed to the same length.
model: Loaded model.
Returns:
- embeddings, A float32 Tensor with shape
(n_sounds, model.scene_embedding_size).
"""
embeddings, _ = get_timestamp_embeddings(audio, model)
embeddings = torch.mean(embeddings, dim=1)
return embeddings
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
"numpy==1.19.2",
"tensorflow",
"torch",
# For wav2vec2 model
"speechbrain",
"transformers==4.4.0"
],
extras_require={
"test": [
Expand Down

0 comments on commit 78bc153

Please sign in to comment.