Skip to content

Commit

Permalink
fixed trill feature extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
felixbur committed Sep 14, 2021
1 parent c22b3fb commit 6d30a4f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 22 deletions.
9 changes: 5 additions & 4 deletions exp_emodb.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ runs = 1
epochs = 1
[DATA]
databases = ['emodb']
emodb = /home/fburkhardt/audb/emodb/6/
#emodb.split_strategy = reuse
emodb.split_strategy = speaker_split
emodb = /home/felix/data/audb/emodb/
emodb.split_strategy = reuse
#emodb.split_strategy = speaker_split
emodb.testsplit = 40
target = emotion
labels = ['anger', 'boredom', 'disgust', 'fear', 'happiness', 'neutral', 'sadness']
[FEATS]
type = os
#type = os
type = trill
[MODEL]
type = xgb
#tuning_params = ['C']
Expand Down
26 changes: 18 additions & 8 deletions src/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from runmanager import Runmanager
from util import Util
import glob_conf
import plots
import ast # To convert strings to objects
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from feats_spectra import Spectraloader
from scaler import Scaler
import pickle

Expand All @@ -19,7 +19,12 @@ class Experiment:


def __init__(self, config_obj):
"""Constructor: takes a name and the config object"""
"""
Parameters
----------
config_obj : a config parser object that sets the experiment parameters and being set as a global object.
"""

self.set_globals(config_obj)
self.name = glob_conf.config['EXP']['name']
self.util = Util()
Expand Down Expand Up @@ -91,21 +96,27 @@ def fill_train_and_tests(self):


def plot_distribution(self):
import plots
"""Plot the distribution of samples and speaker per target class and biological sex"""
fig_dir = self.util.get_path('fig_dir')
plots.describe_df(self.df_test, self.target, fig_dir+f'test_distplot.png')
plots.describe_df(self.df_train, self.target, fig_dir+f'train_distplot.png')

def augment_train(self):
# augment the train and dev dataframes
"""Augment the train dataframe"""
from augmenter import Augmenter
augment_train = Augmenter(self.df_train)
df_train_aug = augment_train.augment()
self.df_train = self.df_train.append(df_train_aug)


def extract_feats(self):
"""Extract the features for train and dev sets. They will be stored on disk and need to be removed manually."""
"""Extract the features for train and dev sets.
They will be stored on disk and need to be removed manually.
The string FEATS.feats_type is read from the config, defaults to os.
"""
df_train, df_test = self.df_train, self.df_test
strategy = self.util.config_val('DATA', 'strategy', 'train_test')
feats_type = self.util.config_val('FEATS', 'type', 'os')
Expand All @@ -132,16 +143,14 @@ def extract_feats(self):
self.feats_train = TRILLset(f'{feats_name}_train', df_train)
self.feats_train.extract()
self.feats_train.filter()
self.feats_test = AudIDset(f'{feats_name}_test', df_test)
self.feats_test = TRILLset(f'{feats_name}_test', df_test)
self.feats_test.extract()
self.feats_test.filter()
elif feats_type=='mld':
from feats_mld import MLD_set
self.feats_train = MLD_set(f'{feats_name}_train', df_train)
self.feats_train.extract()
self.feats_train.filter()
if self.feats_train.df.isna().to_numpy().any():
self.util.error('exp 1: NANs exist')
self.feats_test = MLD_set(f'{feats_name}_test', df_test)
self.feats_test.extract()
self.feats_test.filter()
Expand All @@ -152,6 +161,7 @@ def extract_feats(self):
self.util.error('exp 2: NANs exist')
elif feats_type=='spectra':
# compute the spectrograms
from feats_spectra import Spectraloader # not yet open source
test_specs = Spectraloader(f'{feats_name}_test', df_test)
test_specs.make_feats()
self.feats_test = test_specs.get_loader()
Expand Down
20 changes: 11 additions & 9 deletions src/feats_trill.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@
import os
import tensorflow as tf
# Import TF 2.X and make sure we're running eager.
# tf.enable_v2_behavior()
# assert tf.executing_eagerly()
assert tf.executing_eagerly()
import tensorflow_hub as hub
# Load the module and run inference.
@tf.function(experimental_relax_shapes=True)

class TRILLset(Featureset):
"""A feature extractor for the Google TRILL embeddings"""
"""https://ai.googleblog.com/2020/06/improving-speech-representations-and.html"""

def __init__(self, name, data_df):
self.name = name
self.data_df = data_df
self.util = Util()
self.module = hub.load('https://tfhub.dev/google/nonsemantic-speech-benchmark/trill/3')

def extract(self):
store = self.util.get_path('store')
storage = f'{store}{self.name}.pkl'
Expand All @@ -26,22 +29,21 @@ def extract(self):
except KeyError:
extract = False
if extract or not os.path.isfile(storage):
self.module = hub.load('https://tfhub.dev/google/nonsemantic-speech-benchmark/trill/3')
print('extracting TRILL embeddings, this might take a while...')

self.util.debug('extracting TRILL embeddings, this might take a while...')
emb_series = pd.Series(index = self.data_df.index, dtype=object)
length = len(self.data_df.index)
for idx, file in enumerate(self.data_df.index):
emb = self.getEmbeddings(file)
emb_series[idx] = emb
self.util.debug(f'TRILL: {length}, {idx}')
self.df = pd.DataFrame(emb_series, index=self.data_df.index)
self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index)
self.df.to_pickle(storage)
try:
glob_conf.config['DATA']['needs_feature_extraction'] = 'false'
except KeyError:
pass
else:
self.util.debug('reusing extracted TRILL embeddings')
self.df = pd.read_pickle(storage)

def embed_wav(self, wav):
Expand All @@ -54,7 +56,7 @@ def embed_wav(self, wav):
def getEmbeddings(self, file):
wav = af.read(file)[0]
wav = tf.convert_to_tensor(wav)
emb_short = embed_wav(wav)
emb_short = self.embed_wav(wav)
# you get one embedding per frame, we use the mean for all the frames
emb_short = emb_short.numpy().mean(axis=0)
return emb_short
2 changes: 1 addition & 1 deletion src/runmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from model_svr import SVR_model
from model_xgb import XGB_model
from model_xgr import XGR_model
from model_cnn import CNN_model
from model_mlp import MLP_model
from model_mlp_regression import MLP_Reg_model
from reporter import Reporter
Expand Down Expand Up @@ -47,6 +46,7 @@ def do_runs(self):
elif model_type=='xgr':
self.model = XGR_model(self.df_train, self.df_test, self.feats_train, self.feats_test)
elif model_type=='cnn':
from model_cnn import CNN_model
self.model = CNN_model(self.df_train, self.df_test, self.feats_train, self.feats_test)
elif model_type=='mlp':
self.model = MLP_model(self.df_train, self.df_test, self.feats_train, self.feats_test)
Expand Down

0 comments on commit 6d30a4f

Please sign in to comment.