-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from neuralaudio/wav2vec2
Wav2vec2 HEAR
- Loading branch information
Showing
5 changed files
with
173 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# CI workflow: installs this package plus hearvalidator and checks that the
# hearbaseline.wav2vec2 module conforms to the HEAR API on every pull request.
name: hearvalidator on wav2vec2

on: [pull_request]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so YAML never parses a version as a float (e.g. 3.10 -> 3.1).
        # NOTE(review): older interpreters may no longer be available on the
        # current ubuntu-latest image - confirm before relying on this matrix.
        python-version: ["3.7", "3.8", "3.9"]

    steps:
      # Pin actions to released major versions instead of a moving branch ref.
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: apt-get
        run: |
          # Refresh the package index first so the install does not hit stale mirrors.
          sudo apt-get update
          sudo apt-get install -y libsndfile-dev
      - name: Display Python version
        run: python -c "import sys; print(sys.version)"
      - name: python dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .
          pip install hearvalidator
      - name: Validate
        run: |
          hear-validator hearbaseline.wav2vec2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
name: HEAR Validator on Baselines | ||
name: hearvalidator on baselines | ||
|
||
on: [pull_request] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,4 +129,6 @@ dmypy.json | |
.pyre/ | ||
|
||
# PyCharm | ||
.idea/ | ||
.idea/ | ||
|
||
pretrained/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
""" | ||
wav2vec2 model for HEAR 2021 NeurIPS competition. | ||
Adapted from | ||
https://colab.research.google.com/drive/17Hu1pxqhfMisjkSgmM2CnZxfqDyn2hSY?usp=sharing | ||
""" | ||
|
||
from typing import Tuple | ||
|
||
import torch | ||
|
||
# from speechbrain.lobes.models.fairseq_wav2vec import FairseqWav2Vec2 | ||
from speechbrain.lobes.models.huggingface_wav2vec import HuggingFaceWav2Vec2 | ||
from torch import Tensor | ||
|
||
# HuggingFace model hub | ||
model_hub = "facebook/wav2vec2-base-960h" | ||
|
||
# Fairseq model URL | ||
# model_url = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt" | ||
|
||
|
||
def load_model(model_file_path: str = "") -> torch.nn.Module:
    """
    Build the pretrained wav2vec2 model and tag it with the HEAR API attributes.

    The weights are always fetched from the HuggingFace hub entry named by the
    module-level ``model_hub`` and cached under ``pretrained/``.

    Args:
        model_file_path: Accepted for HEAR API compatibility. This
            implementation does not read it.

    Returns:
        The wav2vec2 module, moved to CUDA when a GPU is available, carrying
        the ``sample_rate``, ``embedding_size``, ``scene_embedding_size`` and
        ``timestamp_embedding_size`` attributes.
    """
    wav2vec2 = HuggingFaceWav2Vec2(model_hub, save_path="pretrained/")
    if torch.cuda.is_available():
        wav2vec2.cuda()

    # The HEAR API requires the sample rate and the embedding sizes to be
    # exposed as model attributes. Scene and timestamp embeddings share one
    # dimensionality here.
    embedding_dim = 768
    wav2vec2.sample_rate = 16000
    wav2vec2.embedding_size = embedding_dim
    wav2vec2.scene_embedding_size = embedding_dim
    wav2vec2.timestamp_embedding_size = embedding_dim

    return wav2vec2
|
||
|
||
def get_timestamp_embeddings(
    audio: Tensor,
    model: torch.nn.Module,
) -> Tuple[Tensor, Tensor]:
    """
    Return frame-level embeddings and the timestamp at the centre of each frame.

    Args:
        audio: n_sounds x n_samples of mono audio in the range [-1, 1].
        model: Loaded model; must be a HuggingFaceWav2Vec2 instance.

    Returns:
        - Tensor: embeddings, as produced by the model, with shape
            (n_sounds, n_timestamps, model.timestamp_embedding_size).
        - Tensor: timestamps, centered timestamps in milliseconds corresponding
            to each embedding in the output. Shape: (n_sounds, n_timestamps).

    Raises:
        ValueError: If ``audio`` is not 2D or ``model`` has the wrong type.
    """
    # Validate the audio shape before doing any work.
    if audio.ndim != 2:
        raise ValueError(
            "audio input tensor must be 2D with shape (n_sounds, num_samples)"
        )

    # Make sure the correct model type was passed in.
    if not isinstance(model, HuggingFaceWav2Vec2):
        raise ValueError(f"Model must be an instance of {HuggingFaceWav2Vec2.__name__}")

    # NOTE: the model is used on whatever device it already lives on; it is not
    # moved here, so the caller must supply audio on a matching device.

    # Inference only: eval mode and no gradient tracking.
    model.eval()
    with torch.no_grad():
        embeddings = model(audio)

    # Empirically (16 kHz input) the number of frames per clip is:
    #   31439 samples => 97    31440 samples => 98
    #   32000 samples => 99    32080 samples => 100
    # which is consistent with a 25 ms (400-sample) window hopped every 20 ms
    # (320 samples): the first frame is centred at 12.5 ms and every following
    # frame is 20 ms later. It is not confirmed that this is the model's exact
    # centering, but it matches the output shape.
    first_center_ms = 12.5
    hop_ms = 20.0

    # Derive the timestamps from the number of frames the model actually
    # produced so that both outputs always agree in shape.
    n_sounds, n_frames = embeddings.shape[0], embeddings.shape[1]
    centers = first_center_ms + hop_ms * torch.arange(n_frames, dtype=torch.float32)
    timestamps = centers.expand((n_sounds, n_frames))

    return embeddings, timestamps
|
||
|
||
# TODO: There must be a better way to do scene embeddings, | ||
# e.g. just truncating / padding the audio to 2 seconds | ||
# and concatenating a subset of the embeddings. | ||
def get_scene_embeddings(
    audio: Tensor,
    model: torch.nn.Module,
) -> Tensor:
    """
    Produce one embedding per audio clip.

    The clip-level embedding is the mean over time of the frame-level
    embeddings returned by get_timestamp_embeddings().

    Args:
        audio: n_sounds x n_samples of mono audio in the range [-1, 1]. All
            sounds in a batch will be padded/trimmed to the same length.
        model: Loaded model.

    Returns:
        - embeddings, a Tensor with shape
            (n_sounds, model.scene_embedding_size).
    """
    frame_embeddings, _ = get_timestamp_embeddings(audio, model)
    # Collapse the time axis (dim 1) by averaging.
    return frame_embeddings.mean(dim=1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters