From d360f52863621b0c828f1762a1be0cbe53e89b68 Mon Sep 17 00:00:00 2001
From: Long Nguyen-Vu
Date: Wed, 27 Dec 2023 14:16:19 +0900
Subject: [PATCH] test w2v-bert

---
 evaluate.py         | 12 ++++-----
 models/__init__.py  |  1 +
 models/w2vassist.py | 12 ++++-----
 oc_classifier.py    | 62 ++++++++++++++++++++++++---------------------
 oc_training.py      | 25 ++----------------
 5 files changed, 48 insertions(+), 64 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index 667fd27..a0ed09e 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -17,7 +17,7 @@ def load_metadata(file_path):
         lines = f.readlines()
         for line in lines:
             line = line.strip()
-            label = line.split(" ")[5]
+            label = line.split(" ")[2]
             labels.append(label)
     return labels
 
@@ -40,8 +40,8 @@ def load_metadata_from_proto(meta_file_path, proto_file_path):
         lines = f.readlines()
         for line in lines:
             line = line.strip()
-            file_name = line.split(" ")[1]
-            label = line.split(" ")[5]
+            file_name = line.split(" ")[0]
+            label = line.split(" ")[2]
             if file_name in protos:
                 index = protos.index(file_name)
                 labels[index] = label
@@ -167,8 +167,8 @@ def calculate_EER(scores, labels):
     # and bonafide otherwise
     # create two lists: one for the labels and one for the predictions
 
-    # labels = metadata
-    labels = load_metadata_from_proto(args.metadata_file, args.protocol_file)
+    labels = metadata
+    # labels = load_metadata_from_proto(args.metadata_file, args.protocol_file)
     predictions = []
     for i, file_name in enumerate(proto):
         score = scores[i]
@@ -194,5 +194,5 @@ def calculate_EER(scores, labels):
     print(f"TN = {cm[1][1]}")
     print(f"FP = {cm[0][1]}")
     print(f"FN = {cm[1][0]}")
-    
+
     calculate_EER(scores, labels)
diff --git a/models/__init__.py b/models/__init__.py
index de07aca..91a6bf3 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,3 +1,4 @@
 # from .lcnn import *
 # from .senet import *
 # from .occm import *
+# from ..seamless_communication.src.seamless_communication.models.conformer_shaw import load_conformer_shaw_model
\ No newline at end of file
diff --git a/models/w2vassist.py b/models/w2vassist.py
index aed55d2..7b9be82 100644
--- a/models/w2vassist.py
+++ b/models/w2vassist.py
@@ -12,12 +12,8 @@ from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.data import Collater
 from pathlib import Path
 
-from seamless_communication.src.seamless_communication.models.conformer_shaw import load_conformer_shaw_model
-
-
-
-
+from seamless_communication.src.seamless_communication.models.conformer_shaw import load_conformer_shaw_model
 
 ___author__ = "Long Nguyen-Vu"
 
@@ -596,7 +592,8 @@ def forward(self, x):
 
         return emb, output
 
-if __name__ == '__main__':
+
+def main():
     import librosa
 
     model = AModel(None,"cuda").to("cuda")
@@ -605,3 +602,6 @@ def forward(self, x):
     emb, out = model(torch.Tensor(audio_data).unsqueeze(0).to("cuda"))
     print(emb.shape)
     print(out.shape)
+
+if __name__ == '__main__':
+    main()
diff --git a/oc_classifier.py b/oc_classifier.py
index 67c0214..73201a2 100644
--- a/oc_classifier.py
+++ b/oc_classifier.py
@@ -1,27 +1,17 @@
 import os
 import argparse
-import librosa
 
 import torch
 import torch.nn.functional as F
 from torch.nn import DataParallel
 from torch.utils.data import Dataset
-from torch.utils.data import DataLoader, random_split
+from torch.utils.data import DataLoader
 
-import numpy as np
-from torchattacks import PGD
+from models.w2vassist import *
 
-
-from models.lcnn import *
-from models.senet import *
-from models.xlsr import *
-from models.sslassist import *
-
-from losses.custom_loss import compactness_loss, descriptiveness_loss, euclidean_distance_loss
 import warnings
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 
-# to be used with one-class classifier
-# input is now a raw audio file
+
 
 class ASVDataset(Dataset):
@@ -51,9 +41,6 @@ def __init__(self, protocol_file, dataset_dir, eval=False):
         self.file_list = []
         self.label_list = []
         self.eval = eval
-        # file_list is now the second column of the protocol file
-        # label list is now the fifth column of the protocol file
-        # read the protocol file
 
         if self.eval:
             # collect all files
@@ -73,12 +60,37 @@ def __init__(self, protocol_file, dataset_dir, eval=False):
                 for line in lines:
                     line = line.strip()
                     line = line.split(" ")
+                    # self.file_list.append(line[0])
+                    # self.label_list.append("bonafide")
 
                     # bonafide only
                     if line[4] == "bonafide":
                         self.file_list.append(line[1])
                         self.label_list.append(line[4]) # bonafide only
 
         self._length = len(self.file_list)
 
+    def _preprocess_wav(self, audio_wav_path):
+        # Create a device and dtype
+        dtype = torch.float32
+        device = torch.device("cuda")
+        audio_decoder = AudioDecoder(dtype=dtype, device=device)
+        fbank_converter = WaveformToFbankConverter(
+            num_mel_bins=80,
+            waveform_scale=2**15,
+            channel_last=True,
+            standardize=True,
+            device=device,
+            dtype=dtype,
+        )
+        collater = Collater(pad_value=1)
+
+        with Path(audio_wav_path).open("rb") as fb:
+            block = MemoryBlock(fb.read())
+            # print("Decoding audio...")
+            decoded_audio = audio_decoder(block)
+            src = collater(fbank_converter(decoded_audio))["fbank"]
+            seqs, padding_mask = get_seqs_and_padding_mask(src)
+            return seqs
+
     def __len__(self):
         return self._length
 
@@ -89,8 +101,7 @@ def __getitem__(self, idx):
         file_path = os.path.join(self.dataset_dir, audio_file + ".flac")
         if not os.path.exists(file_path):
             file_path = os.path.join(self.dataset_dir, audio_file + ".wav")
-
-        feature, _ = librosa.load(file_path, sr=None)
+        feature = self._preprocess_wav(file_path)
         feature_tensors = torch.tensor(feature, dtype=torch.float32)
 
         if self.eval == False:
@@ -179,10 +190,10 @@ def create_reference_embedding2(model, dataloader, device):
     total_distances = []
 
     with torch.no_grad():
-        for _, (data, target) in enumerate(dataloader):
+        for _, (data, _) in enumerate(dataloader):
             data = data.to(device)
-            target = target.to(device)
-            emb, out = model(data) # torch.Size([1, 160])
+            data.squeeze_(0)
+            emb, out = model(data)
             total_embeddings.append(emb)
 
     # reference embedding is the mean of all embeddings
@@ -255,6 +266,7 @@ def score_eval_set_1c2(model, dataloader, device, reference_embedding, threshold
     with torch.no_grad():
         for idx, (data, target) in enumerate(dataloader):
             data = data.to(device)
+            data.squeeze_(0)
             target = target.to(device)
             emb, out = model(data)
             print(f"Processing file counts: {idx} ...")
@@ -264,7 +276,6 @@ def score_eval_set_1c2(model, dataloader, device, reference_embedding, threshold
             else:
                 f.write(f"{float(distance)}, 0 \n")
 
-
 def score_eval_set_2c1(extractor, encoder, dataloader, device):
     """TWO-CLASS APPROACH: Score the evaluation set
     and save the scores to a file
@@ -338,24 +349,17 @@ def score_eval_set_2c2(model, dataloader, device):
 
     # load pretrained weights
     aasist.load_state_dict(torch.load(args.pretrained_sslaasist))
-    # ssl.load_state_dict(torch.load(args.pretrained_ssl))
-    # senet.load_state_dict(torch.load(args.pretrained_senet))
 
     aasist = DataParallel(aasist)
-    # senet = DataParallel(senet)
-    # ssl = DataParallel(ssl)
     print("Pretrained weights loaded")
 
     # create a reference embedding & find a threshold
     train_dataset = ASVDataset(args.protocol_file, args.dataset_dir)
     train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0)
-    # reference_embedding, threshold = create_reference_embedding(ssl, senet, train_dataloader, device)
     reference_embedding, threshold = create_reference_embedding2(aasist, train_dataloader, device)
 
     # score the evaluation set
     eval_dataset = ASVDataset(args.eval_protocol_file, args.eval_dataset_dir, eval=True)
     eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False, num_workers=0)
-    # score_eval_set(ssl, senet, eval_dataloader, device, reference_embedding, threshold)
     score_eval_set_1c2(aasist, eval_dataloader, device, reference_embedding, threshold)
-    # score_eval_set_2c2(aasist, eval_dataloader, device)
 
     print(f"threshold = {threshold}")
diff --git a/oc_training.py b/oc_training.py
index bbceeea..94b064c 100644
--- a/oc_training.py
+++ b/oc_training.py
@@ -335,38 +335,22 @@ def collate_fn(self, batch):
     # Define the collate function
     train_dataset = PFDataset(args.train_protocol_file, dataset_dir=args.train_dataset_dir)
-    # test_dataset = PFDataset(args.test_protocol_file, dataset_dir=args.test_dataset_dir)
 
     # Create dataloaders for training and validation
     batch_size = 1
     print("Creating dataloaders...")
-    # train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=train_dataset.collate_fn)
     train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
-    # test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=test_dataset.collate_fn)
-
     print("Instantiating model...")
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     aasist = AModel(None, device).to(device)
-    # ssl = SSLModel(device)
-    # senet34 = se_resnet34().to(device)
-    # lcnn = lcnn_net(asoftmax=False).to(device)
+
     optimizer = optim.Adam(aasist.parameters(), lr=0.00001)
-    # optimizer = optim.Adam(list(ssl.parameters()) + list(senet34.parameters()) + list(lcnn.parameters()), lr=0.0001)
-    # optimizer = optim.Adam(list(ssl.parameters()) + list(senet34.parameters()), lr=0.00001, weight_decay=0.0005)
 
     aasist = DataParallel(aasist)
-    # ssl = DataParallel(ssl)
-    # senet34 = DataParallel(senet34)
-    # lcnn = DataParallel(lcnn)
-
-    # if args.model == "lcnn_net_asoftmax":
-    #     criterion = AngleLoss()
-
-    # WandB – Initialize a new run
+
     wandb.init(project="oc_classifier-w2v", entity="longnv")
@@ -382,9 +366,7 @@ def collate_fn(self, batch):
 
         # Training phase
         aasist.train()
-        # ssl.eval()
-        # senet34.train()
-        # lcnn.train()
+
 
         running_loss = 0.0
         running_closs = 0.0
@@ -422,7 +404,4 @@ def collate_fn(self, batch):
         wandb.log({"Epoch": epoch, "Train Loss": running_loss / (i+1), "Train Compactness Loss": running_closs / (i+1), "Train Descriptiveness Loss": running_dloss / (i+1)})
         # save the models after each epoch
         print("Saving the models...")
-        # torch.save(ssl.module.state_dict(), f"ssl_vocoded_{epoch}.pt")
-        # torch.save(senet34.module.state_dict(), f"senet34_vocoded_{epoch}.pt")
         torch.save(aasist.module.state_dict(), f"w2v_vocoded_{epoch}.pt")
-        # torch.save(lcnn.module.state_dict(), f"lcnn_{epoch}.pt")
\ No newline at end of file
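
Note (not part of the patch): with this change the dataset in oc_classifier.py feeds the model 80-bin fbank features computed with fairseq2 instead of raw waveforms loaded via librosa. The sketch below restates the added _preprocess_wav logic as a standalone function; the fairseq2 import paths for AudioDecoder, WaveformToFbankConverter and MemoryBlock are assumed from fairseq2's public API and do not appear in this diff.

import torch
from pathlib import Path

from fairseq2.data import Collater
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter  # assumed import path
from fairseq2.memory import MemoryBlock                                 # assumed import path
from fairseq2.nn.padding import get_seqs_and_padding_mask


def wav_to_fbank(audio_path, device=torch.device("cuda")):
    # Mirrors the ASVDataset._preprocess_wav method added by this patch.
    dtype = torch.float32
    audio_decoder = AudioDecoder(dtype=dtype, device=device)
    fbank_converter = WaveformToFbankConverter(
        num_mel_bins=80,        # same settings as the patch
        waveform_scale=2**15,   # scale float samples to 16-bit range before fbank
        channel_last=True,
        standardize=True,
        device=device,
        dtype=dtype,
    )
    collater = Collater(pad_value=1)

    with Path(audio_path).open("rb") as fb:
        block = MemoryBlock(fb.read())      # raw file bytes for the decoder
    decoded_audio = audio_decoder(block)    # decoded waveform plus metadata
    src = collater(fbank_converter(decoded_audio))["fbank"]
    seqs, padding_mask = get_seqs_and_padding_mask(src)
    return seqs                             # (1, num_frames, 80) feature tensor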
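
Note (not part of the patch): create_reference_embedding2 and score_eval_set_1c2 are only partially visible in the hunks above. As far as the diff shows, the one-class flow is: embed the bona fide training utterances, average them into a reference embedding, derive a distance threshold from the training distances, then write a distance and a 0/1 decision per evaluation utterance. The sketch below illustrates that flow and is not the repository code; the function names, the max-distance threshold, the pairwise-distance metric and the accept-if-below-threshold rule are assumptions.

import torch
import torch.nn.functional as F


@torch.no_grad()
def build_reference(model, dataloader, device):
    # Embed every bona fide training utterance and average the embeddings.
    embeddings = []
    for data, _ in dataloader:
        data = data.to(device)
        data.squeeze_(0)  # drop the extra batch dimension added by the DataLoader, as the patch does
        emb, _ = model(data)
        embeddings.append(emb)
    reference = torch.stack(embeddings).mean(dim=0)
    # Assumed threshold: the largest training distance to the reference.
    distances = [F.pairwise_distance(e, reference).item() for e in embeddings]
    return reference, max(distances)


@torch.no_grad()
def score_eval(model, dataloader, device, reference, threshold, out_path="scores.txt"):
    # Write "<distance>, <decision>" per utterance, matching the 0/1 lines in the diff.
    with open(out_path, "w") as f:
        for data, _ in dataloader:
            data = data.to(device)
            data.squeeze_(0)
            emb, _ = model(data)
            distance = F.pairwise_distance(emb, reference).item()
            decision = 1 if distance < threshold else 0  # assumed decision rule
            f.write(f"{distance}, {decision} \n")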