Merge pull request #9 from stevenguh/feature/cuda
Feature/cuda
Harri Taylor authored Feb 25, 2020
2 parents de6d377 + 48521ec commit 51bab6b
Showing 1 changed file with 64 additions and 17 deletions.
81 changes: 64 additions & 17 deletions torchvggish/vggish.py
@@ -1,8 +1,9 @@
+import numpy as np
 import torch
 import torch.nn as nn
 from torch import hub
-import numpy as np
-from . import vggish_input
+
+from . import vggish_input, vggish_params
 
 
 class VGG(nn.Module):
@@ -30,7 +31,7 @@ def forward(self, x):
         return self.embeddings(x)
 
 
-class Postprocessor(object):
+class Postprocessor(nn.Module):
     """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a
     numpy array in order to preserve the gradient.
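The switch from object to nn.Module is the heart of the CUDA feature: a module's registered tensors follow .to(device) / .cuda() and show up in state_dict(), which a plain Python object never gets for free. A minimal standalone sketch of that behaviour (the Affine class here is illustrative, not part of the repo):

import torch
import torch.nn as nn

class Affine(nn.Module):
    def __init__(self):
        super().__init__()
        # A frozen tensor registered as a Parameter still travels with the module.
        self.w = nn.Parameter(torch.eye(3), requires_grad=False)

    def forward(self, x):
        return x @ self.w

m = Affine()
print(list(m.state_dict()))  # ['w'], serialized like any weight
if torch.cuda.is_available():
    m = m.cuda()             # m.w moves to the GPU with the module
    print(m.w.device)        # cuda:0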
@@ -43,11 +44,20 @@ class Postprocessor(object):
     the same PCA (with whitening) and quantization transformations."
     """
 
-    def __init__(self, params):
+    def __init__(self):
         """Constructs a postprocessor."""
-        params = hub.load_state_dict_from_url(params)
-        self._pca_matrix = torch.as_tensor(params["pca_eigen_vectors"]).float()
-        self._pca_means = torch.as_tensor(params["pca_means"].reshape(-1, 1)).float()
+        super(Postprocessor, self).__init__()
+        # Create empty matrices for the user's state_dict to load into.
+        self.pca_eigen_vectors = torch.empty(
+            (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE),
+            dtype=torch.float,
+        )
+        self.pca_means = torch.empty(
+            (vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float
+        )
+
+        self.pca_eigen_vectors = nn.Parameter(self.pca_eigen_vectors, requires_grad=False)
+        self.pca_means = nn.Parameter(self.pca_means, requires_grad=False)
 
     def postprocess(self, embeddings_batch):
         """Applies tensor postprocessing to a batch of embeddings.
@@ -60,15 +70,40 @@ def postprocess(self, embeddings_batch):
          A tensor of the same shape as the input, containing the PCA-transformed,
          quantized, and clipped version of the input.
         """
-        pca_applied = torch.mm(
-            self._pca_matrix, (embeddings_batch.t() - self._pca_means)
-        ).t()
-        clipped_embeddings = torch.clamp(pca_applied, -2.0, +2.0)
+        assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % (
+            embeddings_batch.shape,
+        )
+        assert (
+            embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE
+        ), "Bad batch shape: %r" % (embeddings_batch.shape,)
+
+        # Apply PCA.
+        # - Embeddings come in as [batch_size, embedding_size].
+        # - Transpose to [embedding_size, batch_size].
+        # - Subtract pca_means column vector from each column.
+        # - Premultiply by PCA matrix of shape [output_dims, input_dims]
+        #   where both are equal to embedding_size in our case.
+        # - Transpose result back to [batch_size, embedding_size].
+        pca_applied = torch.mm(self.pca_eigen_vectors, (embeddings_batch.T - self.pca_means)).T
+
+        # Quantize by:
+        # - clipping to [min, max] range
+        clipped_embeddings = torch.clamp(
+            pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL
+        )
+        # - convert to 8-bit in range [0.0, 255.0]
         quantized_embeddings = torch.round(
-            (clipped_embeddings - -2.0) * (255.0 / (+2.0 - -2.0))
+            (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL)
+            * (
+                255.0
+                / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL)
+            )
         )
         return torch.squeeze(quantized_embeddings)
 
+    def forward(self, x):
+        return self.postprocess(x)
+
 
 def make_layers():
     layers = []
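The named constants replace the hard-coded -2.0 / +2.0 literals from the deleted lines, so QUANTIZE_MIN_VAL = -2.0 and QUANTIZE_MAX_VAL = +2.0 here. The clamp-then-round step therefore maps PCA outputs linearly from [-2.0, 2.0] onto [0.0, 255.0]; a tiny standalone check of just that arithmetic:

import torch

QUANTIZE_MIN_VAL, QUANTIZE_MAX_VAL = -2.0, 2.0  # per the deleted literals above

x = torch.tensor([-3.0, -2.0, 0.0, 1.0, 2.5])
clipped = torch.clamp(x, QUANTIZE_MIN_VAL, QUANTIZE_MAX_VAL)  # [-2, -2, 0, 1, 2]
quantized = torch.round(
    (clipped - QUANTIZE_MIN_VAL)
    * (255.0 / (QUANTIZE_MAX_VAL - QUANTIZE_MIN_VAL))
)
print(quantized)  # tensor([  0.,   0., 128., 191., 255.])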
@@ -106,15 +141,27 @@ def _vgg():
 
 
 class VGGish(VGG):
-    def __init__(self, urls, trained=True, preprocess=True, postprocess=True):
+    def __init__(self, urls, pretrained=True, preprocess=True, postprocess=True, progress=True):
         super().__init__(make_layers())
-        if trained:
-            state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=True)
+        if pretrained:
+            state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=progress)
             super().load_state_dict(state_dict)
 
         self.preprocess = preprocess
         self.postprocess = postprocess
-        self.pproc = Postprocessor(urls['pca'])
+        if self.postprocess:
+            self.pproc = Postprocessor()
+            if pretrained:
+                state_dict = hub.load_state_dict_from_url(urls['pca'], progress=progress)
+                # TODO: Convert the state_dict to torch
+                state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor(
+                    state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float
+                )
+                state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor(
+                    state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float
+                )
+
+                self.pproc.load_state_dict(state_dict)
 
     def forward(self, x, fs=None):
         if self.preprocess:
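Because hub.load_state_dict_from_url(urls['pca']) returns a dict of numpy arrays (keyed 'pca_eigen_vectors' and 'pca_means', as the deleted constructor shows), the values are converted to float tensors and pca_means is reshaped to a column vector before load_state_dict copies them into the Postprocessor's frozen Parameters. A sketch of that flow with random stand-in data instead of the real PCA file:

import numpy as np
import torch

from torchvggish.vggish import Postprocessor

# Stand-in for the downloaded PCA params; shapes match the 128-D embedding.
state_dict = {
    "pca_eigen_vectors": np.random.randn(128, 128).astype(np.float32),
    "pca_means": np.random.randn(128).astype(np.float32),
}
state_dict["pca_eigen_vectors"] = torch.as_tensor(
    state_dict["pca_eigen_vectors"], dtype=torch.float
)
state_dict["pca_means"] = torch.as_tensor(
    state_dict["pca_means"].reshape(-1, 1), dtype=torch.float  # column vector
)

pproc = Postprocessor()
pproc.load_state_dict(state_dict)  # copies into the frozen Parameters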
@@ -134,4 +181,4 @@ def _preprocess(self, x, fs):
         return x
 
     def _postprocess(self, x):
-        return self.pproc.postprocess(x)
+        return self.pproc(x)
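Net effect of the commit: the postprocessor is now part of the module tree, so a single .to(device) call moves the convolutional weights and the PCA parameters together, and calling self.pproc(x) routes through nn.Module's __call__ machinery. A usage sketch under stated assumptions: the URLs are placeholders for the repo's released weight files, and the [n, 1, 96, 64] shape is the standard VGGish log-mel example format, passed directly since preprocess=False:

import torch
from torchvggish.vggish import VGGish

urls = {
    "vggish": "https://example.com/vggish.pth",          # placeholder URL
    "pca": "https://example.com/vggish_pca_params.pth",  # placeholder URL
}

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VGGish(urls, pretrained=True, preprocess=False, postprocess=True)
model = model.to(device).eval()  # PCA Parameters move along with the conv weights

# Four pre-computed 0.96 s log-mel examples (96 frames x 64 mel bands).
examples = torch.randn(4, 1, 96, 64, device=device)
with torch.no_grad():
    embeddings = model(examples)  # postprocessed 128-D embeddings per example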
