test w2v-bert
nguyenvulong committed Dec 27, 2023
1 parent 8b12fc5 commit d360f52
Showing 5 changed files with 48 additions and 64 deletions.
12 changes: 6 additions & 6 deletions evaluate.py
@@ -17,7 +17,7 @@ def load_metadata(file_path):
lines = f.readlines()
for line in lines:
line = line.strip()
label = line.split(" ")[5]
label = line.split(" ")[2]
labels.append(label)
return labels

@@ -40,8 +40,8 @@ def load_metadata_from_proto(meta_file_path, proto_file_path):
lines = f.readlines()
for line in lines:
line = line.strip()
file_name = line.split(" ")[1]
label = line.split(" ")[5]
file_name = line.split(" ")[0]
label = line.split(" ")[2]
if file_name in protos:
index = protos.index(file_name)
labels[index] = label
@@ -167,8 +167,8 @@ def calculate_EER(scores, labels):
# and bonafide otherwise

# create two lists: one for the labels and one for the predictions
# labels = metadata
labels = load_metadata_from_proto(args.metadata_file, args.protocol_file)
labels = metadata
# labels = load_metadata_from_proto(args.metadata_file, args.protocol_file)
predictions = []
for i, file_name in enumerate(proto):
score = scores[i]
@@ -194,5 +194,5 @@ def calculate_EER(scores, labels):
print(f"TN = {cm[1][1]}")
print(f"FP = {cm[0][1]}")
print(f"FN = {cm[1][0]}")

calculate_EER(scores, labels)
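Note: this change tracks a new protocol layout. The metadata file is now space-separated with the file name in column 0 and the label in column 2 (previously columns 1 and 5), and labels are taken directly from the metadata file rather than being reordered against the protocol. A minimal sketch of the updated parsing, assuming one "<file_name> ... <label>" record per line:

    def load_metadata(file_path):
        labels = []
        with open(file_path) as f:
            for line in f:
                parts = line.strip().split(" ")
                labels.append(parts[2])  # label moved from column 5 to column 2
        return labels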
1 change: 1 addition & 0 deletions models/__init__.py
@@ -1,3 +1,4 @@
# from .lcnn import *
# from .senet import *
# from .occm import *
# from ..seamless_communication.src.seamless_communication.models.conformer_shaw import load_conformer_shaw_model
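Note: the new (still commented-out) import exposes the conformer_shaw w2v-BERT encoder from the vendored seamless_communication submodule, which explains the seamless_communication.src.seamless_communication path. A sketch of how such a checkpoint is typically loaded, following the seamless_communication examples (the "conformer_shaw" card name is an assumption here):

    import torch
    from seamless_communication.src.seamless_communication.models.conformer_shaw import (
        load_conformer_shaw_model,
    )

    # Load the w2v-BERT (conformer_shaw) speech encoder onto the GPU.
    model = load_conformer_shaw_model(
        "conformer_shaw", device=torch.device("cuda"), dtype=torch.float32
    )
    model.eval()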
12 changes: 6 additions & 6 deletions models/w2vassist.py
@@ -12,12 +12,8 @@
from fairseq2.nn.padding import get_seqs_and_padding_mask
from fairseq2.data import Collater
from pathlib import Path
from seamless_communication.src.seamless_communication.models.conformer_shaw import load_conformer_shaw_model





from seamless_communication.src.seamless_communication.models.conformer_shaw import load_conformer_shaw_model


__author__ = "Long Nguyen-Vu"
@@ -596,7 +592,8 @@ def forward(self, x):

return emb, output

if __name__ == '__main__':

def main():
import librosa

model = AModel(None, "cuda").to("cuda")
@@ -605,3 +602,6 @@ def forward(self, x):
emb, out = model(torch.Tensor(audio_data).unsqueeze(0).to("cuda"))
print(emb.shape)
print(out.shape)

if __name__ == '__main__':
main()
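Note: moving the smoke test into main() keeps the librosa import and the CUDA model instantiation from running whenever w2vassist is imported elsewhere (e.g., by oc_classifier.py). A minimal sketch of the refactored entry point, assuming a 16 kHz test clip at a hypothetical test.wav:

    def main():
        import librosa

        model = AModel(None, "cuda").to("cuda")
        audio_data, _ = librosa.load("test.wav", sr=16000)  # hypothetical path
        emb, out = model(torch.Tensor(audio_data).unsqueeze(0).to("cuda"))
        print(emb.shape)  # embedding shape
        print(out.shape)  # classifier output shape

    if __name__ == '__main__':
        main()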
62 changes: 33 additions & 29 deletions oc_classifier.py
@@ -1,27 +1,17 @@
import os
import argparse
import librosa
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
from torch.utils.data import DataLoader

import numpy as np
from torchattacks import PGD
from models.w2vassist import *


from models.lcnn import *
from models.senet import *
from models.xlsr import *
from models.sslassist import *

from losses.custom_loss import compactness_loss, descriptiveness_loss, euclidean_distance_loss
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# to be used with one-class classifier
# input is now a raw audio file



class ASVDataset(Dataset):
@@ -51,9 +41,6 @@ def __init__(self, protocol_file, dataset_dir, eval=False):
self.file_list = []
self.label_list = []
self.eval = eval
# file_list is now the second column of the protocol file
# label list is now the fifth column of the protocol file
# read the protocol file

if self.eval:
# collect all files
@@ -73,12 +60,37 @@ def __init__(self, protocol_file, dataset_dir, eval=False):
for line in lines:
line = line.strip()
line = line.split(" ")
# self.file_list.append(line[0])
# self.label_list.append("bonafide") # bonafide only
if line[4] == "bonafide":
self.file_list.append(line[1])
self.label_list.append(line[4]) # bonafide only

self._length = len(self.file_list)
def _preprocess_wav(self, audio_wav_path):
# Create a device and dtype
dtype = torch.float32
device = torch.device("cuda")
audio_decoder = AudioDecoder(dtype=dtype, device=device)
fbank_converter = WaveformToFbankConverter(
num_mel_bins=80,
waveform_scale=2**15,
channel_last=True,
standardize=True,
device=device,
dtype=dtype,
)
collater = Collater(pad_value=1)


with Path(audio_wav_path).open("rb") as fb:
block = MemoryBlock(fb.read())
# print("Decoding audio...")
decoded_audio = audio_decoder(block)
src = collater(fbank_converter(decoded_audio))["fbank"]
seqs, padding_mask = get_seqs_and_padding_mask(src)
return seqs

def __len__(self):
return self._length
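Note: with the new _preprocess_wav, the dataset returns 80-bin log-mel filterbank sequences built with fairseq2 instead of raw librosa waveforms, matching what the conformer_shaw encoder expects. A standalone sketch of the same pipeline, with import paths as in the seamless_communication examples (the sample path is hypothetical):

    import torch
    from pathlib import Path
    from fairseq2.memory import MemoryBlock
    from fairseq2.data import Collater
    from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
    from fairseq2.nn.padding import get_seqs_and_padding_mask

    device, dtype = torch.device("cuda"), torch.float32
    decoder = AudioDecoder(dtype=dtype, device=device)
    to_fbank = WaveformToFbankConverter(
        num_mel_bins=80, waveform_scale=2**15,
        channel_last=True, standardize=True,
        device=device, dtype=dtype,
    )
    collater = Collater(pad_value=1)

    with Path("sample.wav").open("rb") as fb:  # hypothetical file
        block = MemoryBlock(fb.read())
    src = collater(to_fbank(decoder(block)))["fbank"]
    seqs, padding_mask = get_seqs_and_padding_mask(src)  # seqs: (1, frames, 80)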

@@ -89,8 +101,7 @@ def __getitem__(self, idx):
file_path = os.path.join(self.dataset_dir, audio_file + ".flac")
if not os.path.exists(file_path):
file_path = os.path.join(self.dataset_dir, audio_file + ".wav")

feature, _ = librosa.load(file_path, sr=None)
feature = self._preprocess_wav(file_path)
feature_tensors = torch.tensor(feature, dtype=torch.float32)

if self.eval == False:
@@ -179,10 +190,10 @@ def create_reference_embedding2(model, dataloader, device):
total_distances = []

with torch.no_grad():
for _, (data, target) in enumerate(dataloader):
for _, (data, _) in enumerate(dataloader):
data = data.to(device)
target = target.to(device)
emb, out = model(data) # torch.Size([1, 160])
data.squeeze_(0)
emb, out = model(data)
total_embeddings.append(emb)

# reference embedding is the mean of all embeddings
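Note: with the fbank features the DataLoader adds a second batch dimension, hence the new data.squeeze_(0) before the forward pass; the unused targets are also dropped. A minimal sketch of the reference-embedding computation, assuming the model returns (embedding, output) and leaving out the threshold statistics over the training set:

    import torch

    def make_reference(model, dataloader, device):
        embs = []
        with torch.no_grad():
            for data, _ in dataloader:       # labels are unused here
                data = data.to(device)
                data.squeeze_(0)             # fbank batch already has a leading dim
                emb, _ = model(data)
                embs.append(emb)
        return torch.stack(embs).mean(dim=0) # mean of bonafide embeddings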
@@ -255,6 +266,7 @@ def score_eval_set_1c2(model, dataloader, device, reference_embedding, threshold
with torch.no_grad():
for idx, (data, target) in enumerate(dataloader):
data = data.to(device)
data.squeeze_(0)
target = target.to(device)
emb, out = model(data)
print(f"Processing file counts: {idx} ...")
@@ -264,7 +276,6 @@ def score_eval_set_1c2(model, dataloader, device, reference_embedding, threshold
else:
f.write(f"{float(distance)}, 0 \n")


def score_eval_set_2c1(extractor, encoder, dataloader, device):
"""TWO-CLASS APPROACH:
Score the evaluation set and save the scores to a file
@@ -338,24 +349,17 @@ def score_eval_set_2c2(model, dataloader, device):

# load pretrained weights
aasist.load_state_dict(torch.load(args.pretrained_sslaasist))
# ssl.load_state_dict(torch.load(args.pretrained_ssl))
# senet.load_state_dict(torch.load(args.pretrained_senet))
aasist = DataParallel(aasist)
# senet = DataParallel(senet)
# ssl = DataParallel(ssl)
print("Pretrained weights loaded")

# create a reference embedding & find a threshold
train_dataset = ASVDataset(args.protocol_file, args.dataset_dir)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0)
# reference_embedding, threshold = create_reference_embedding(ssl, senet, train_dataloader, device)
reference_embedding, threshold = create_reference_embedding2(aasist, train_dataloader, device)

# score the evaluation set
eval_dataset = ASVDataset(args.eval_protocol_file, args.eval_dataset_dir, eval=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False, num_workers=0)
# score_eval_set(ssl, senet, eval_dataloader, device, reference_embedding, threshold)
score_eval_set_1c2(aasist, eval_dataloader, device, reference_embedding, threshold)
# score_eval_set_2c2(aasist, eval_dataloader, device)

print(f"threshold = {threshold}")
25 changes: 2 additions & 23 deletions oc_training.py
@@ -335,38 +335,22 @@ def collate_fn(self, batch):
# Define the collate function

train_dataset = PFDataset(args.train_protocol_file, dataset_dir=args.train_dataset_dir)
# test_dataset = PFDataset(args.test_protocol_file, dataset_dir=args.test_dataset_dir)

# Create dataloaders for training and validation
batch_size = 1

print("Creating dataloaders...")
# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=train_dataset.collate_fn)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=test_dataset.collate_fn)

print("Instantiating model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

aasist = AModel(None, device).to(device)
# ssl = SSLModel(device)
# senet34 = se_resnet34().to(device)
# lcnn = lcnn_net(asoftmax=False).to(device)

optimizer = optim.Adam(aasist.parameters(), lr=0.00001)
# optimizer = optim.Adam(list(ssl.parameters()) + list(senet34.parameters()) + list(lcnn.parameters()), lr=0.0001)
# optimizer = optim.Adam(list(ssl.parameters()) + list(senet34.parameters()), lr=0.00001, weight_decay=0.0005)

aasist = DataParallel(aasist)

# ssl = DataParallel(ssl)
# senet34 = DataParallel(senet34)
# lcnn = DataParallel(lcnn)

# if args.model == "lcnn_net_asoftmax":
# criterion = AngleLoss()


# WandB – Initialize a new run
wandb.init(project="oc_classifier-w2v", entity="longnv")

@@ -382,9 +366,7 @@ def collate_fn(self, batch):

# Training phase
aasist.train()
# ssl.eval()
# senet34.train()
# lcnn.train()


running_loss = 0.0
running_closs = 0.0
@@ -422,7 +404,4 @@ def collate_fn(self, batch):
wandb.log({"Epoch": epoch, "Train Loss": running_loss / (i+1), "Train Compactness Loss": running_closs / (i+1), "Train Descriptiveness Loss": running_dloss / (i+1)})
# save the models after each epoch
print("Saving the models...")
# torch.save(ssl.module.state_dict(), f"ssl_vocoded_{epoch}.pt")
# torch.save(senet34.module.state_dict(), f"senet34_vocoded_{epoch}.pt")
torch.save(aasist.module.state_dict(), f"w2v_vocoded_{epoch}.pt")
# torch.save(lcnn.module.state_dict(), f"lcnn_{epoch}.pt")
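Note: training now keeps only the AASIST-style AModel on top of the w2v front end, with the SSL/SENet/LCNN branches commented out. A minimal sketch of one optimization step, assuming the combined objective is a plain sum of the compactness and descriptiveness terms from losses/custom_loss.py (their exact signatures and weighting are assumptions):

    def train_step(model, optimizer, data, target, device):
        data, target = data.to(device), target.to(device)
        emb, out = model(data)
        closs = compactness_loss(emb)              # pull bonafide embeddings together
        dloss = descriptiveness_loss(out, target)  # keep outputs discriminative
        loss = closs + dloss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item(), closs.item(), dloss.item()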
