diff --git a/models/__init__.py b/models/__init__.py
index 1033995..de07aca 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,3 +1,3 @@
-from .lcnn import *
-from .senet import *
-from .occm import *
+# from .lcnn import *
+# from .senet import *
+# from .occm import *
diff --git a/models/sslassist.py b/models/sslassist.py
index ff80121..055d359 100644
--- a/models/sslassist.py
+++ b/models/sslassist.py
@@ -21,7 +21,7 @@ class SSLModel(nn.Module):
     def __init__(self,device):
         super(SSLModel, self).__init__()
-        cp_path = '/datac/longnv/SSL_Anti-spoofing/pretrained/xlsr2_300m.pt'  # Change the pre-trained XLSR model path.
+        cp_path = '/home/longnv/BTS-Encoder-ASVspoof/demo2023/pretrained/xlsr2_300m.pt'  # Change the pre-trained XLSR model path.
         model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
         self.model = model[0]
         self.device=device
diff --git a/models/w2vassist.py b/models/w2vassist.py
new file mode 100644
index 0000000..aed55d2
--- /dev/null
+++ b/models/w2vassist.py
@@ -0,0 +1,607 @@
+import random
+from typing import Union
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+import fairseq2
+from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+from fairseq2.memory import MemoryBlock
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+from fairseq2.data import Collater
+from pathlib import Path
+from seamless_communication.models.conformer_shaw import load_conformer_shaw_model
+
+
+__author__ = "Long Nguyen-Vu"
+__email__ = "long@ssu.ac.kr"
+
+############################
+## FOR fine-tuned SSL MODEL
+############################
+
+
+class SSLModel(nn.Module):
+    def __init__(self, device):
+        super(SSLModel, self).__init__()
+        self.dtype = torch.float32
+        self.device = device
+        self.out_dim = 1024
+        self.model = load_conformer_shaw_model("conformer_shaw", device=device, dtype=self.dtype)
+        self.model.eval()
+
+    def extract_feat(self, seqs):
+        # with torch.inference_mode():
+        with torch.no_grad():
+            seqs, padding_mask = self.model.encoder_frontend(seqs, None)
+            seqs, padding_mask = self.model.encoder(seqs, None)
+
+        return seqs
+
+
+#---------AASIST back-end------------------------#
+''' Jee-weon Jung, Hee-Soo Heo, Hemlata Tak, Hye-jin Shim, Joon Son Chung, Bong-Jin Lee, Ha-Jin Yu and Nicholas Evans.
+    AASIST: Audio Anti-Spoofing Using Integrated Spectro-Temporal Graph Attention Networks.
+    In Proc. ICASSP 2022, pp: 6367--6371.'''
+
+
+class GraphAttentionLayer(nn.Module):
+    def __init__(self, in_dim, out_dim, **kwargs):
+        super().__init__()
+
+        # attention map
+        self.att_proj = nn.Linear(in_dim, out_dim)
+        self.att_weight = self._init_new_params(out_dim, 1)
+
+        # project
+        self.proj_with_att = nn.Linear(in_dim, out_dim)
+        self.proj_without_att = nn.Linear(in_dim, out_dim)
+
+        # batch norm
+        self.bn = nn.BatchNorm1d(out_dim)
+
+        # dropout for inputs
+        self.input_drop = nn.Dropout(p=0.2)
+
+        # activate
+        self.act = nn.SELU(inplace=True)
+
+        # temperature
+        self.temp = 1.
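+        # dividing the attention logits by temp > 1 flattens the softmax;
+        # AModel below passes temperatures of up to 100.0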
+        if "temperature" in kwargs:
+            self.temp = kwargs["temperature"]
+
+    def forward(self, x):
+        '''
+        x   :(#bs, #node, #dim)
+        '''
+        # apply input dropout
+        x = self.input_drop(x)
+
+        # derive attention map
+        att_map = self._derive_att_map(x)
+
+        # projection
+        x = self._project(x, att_map)
+
+        # apply batch norm
+        x = self._apply_BN(x)
+        x = self.act(x)
+        return x
+
+    def _pairwise_mul_nodes(self, x):
+        '''
+        Calculates pairwise multiplication of nodes.
+        - for attention map
+        x           :(#bs, #node, #dim)
+        out_shape   :(#bs, #node, #node, #dim)
+        '''
+
+        nb_nodes = x.size(1)
+        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
+        x_mirror = x.transpose(1, 2)
+
+        return x * x_mirror
+
+    def _derive_att_map(self, x):
+        '''
+        x           :(#bs, #node, #dim)
+        out_shape   :(#bs, #node, #node, 1)
+        '''
+        att_map = self._pairwise_mul_nodes(x)
+        # size: (#bs, #node, #node, #dim_out)
+        att_map = torch.tanh(self.att_proj(att_map))
+        # size: (#bs, #node, #node, 1)
+        att_map = torch.matmul(att_map, self.att_weight)
+
+        # apply temperature
+        att_map = att_map / self.temp
+
+        att_map = F.softmax(att_map, dim=-2)
+
+        return att_map
+
+    def _project(self, x, att_map):
+        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
+        x2 = self.proj_without_att(x)
+
+        return x1 + x2
+
+    def _apply_BN(self, x):
+        org_size = x.size()
+        x = x.view(-1, org_size[-1])
+        x = self.bn(x)
+        x = x.view(org_size)
+
+        return x
+
+    def _init_new_params(self, *size):
+        out = nn.Parameter(torch.FloatTensor(*size))
+        nn.init.xavier_normal_(out)
+        return out
+
+
+class HtrgGraphAttentionLayer(nn.Module):
+    def __init__(self, in_dim, out_dim, **kwargs):
+        super().__init__()
+
+        self.proj_type1 = nn.Linear(in_dim, in_dim)
+        self.proj_type2 = nn.Linear(in_dim, in_dim)
+
+        # attention map
+        self.att_proj = nn.Linear(in_dim, out_dim)
+        self.att_projM = nn.Linear(in_dim, out_dim)
+
+        self.att_weight11 = self._init_new_params(out_dim, 1)
+        self.att_weight22 = self._init_new_params(out_dim, 1)
+        self.att_weight12 = self._init_new_params(out_dim, 1)
+        self.att_weightM = self._init_new_params(out_dim, 1)
+
+        # project
+        self.proj_with_att = nn.Linear(in_dim, out_dim)
+        self.proj_without_att = nn.Linear(in_dim, out_dim)
+
+        self.proj_with_attM = nn.Linear(in_dim, out_dim)
+        self.proj_without_attM = nn.Linear(in_dim, out_dim)
+
+        # batch norm
+        self.bn = nn.BatchNorm1d(out_dim)
+
+        # dropout for inputs
+        self.input_drop = nn.Dropout(p=0.2)
+
+        # activate
+        self.act = nn.SELU(inplace=True)
+
+        # temperature
+        self.temp = 1.
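+        # same kwargs-based temperature override as GraphAttentionLayer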
+        if "temperature" in kwargs:
+            self.temp = kwargs["temperature"]
+
+    def forward(self, x1, x2, master=None):
+        '''
+        x1  :(#bs, #node, #dim)
+        x2  :(#bs, #node, #dim)
+        '''
+        #print('x1',x1.shape)
+        #print('x2',x2.shape)
+        num_type1 = x1.size(1)
+        num_type2 = x2.size(1)
+        #print('num_type1',num_type1)
+        #print('num_type2',num_type2)
+        x1 = self.proj_type1(x1)
+        #print('proj_type1',x1.shape)
+        x2 = self.proj_type2(x2)
+        #print('proj_type2',x2.shape)
+        x = torch.cat([x1, x2], dim=1)
+        #print('Concat x1 and x2',x.shape)
+
+        if master is None:
+            master = torch.mean(x, dim=1, keepdim=True)
+        #print('master',master.shape)
+        # apply input dropout
+        x = self.input_drop(x)
+
+        # derive attention map
+        att_map = self._derive_att_map(x, num_type1, num_type2)
+        #print('master',master.shape)
+        # directional edge for master node
+        master = self._update_master(x, master)
+        #print('master',master.shape)
+        # projection
+        x = self._project(x, att_map)
+        #print('proj x',x.shape)
+        # apply batch norm
+        x = self._apply_BN(x)
+        x = self.act(x)
+
+        x1 = x.narrow(1, 0, num_type1)
+        #print('x1',x1.shape)
+        x2 = x.narrow(1, num_type1, num_type2)
+        #print('x2',x2.shape)
+        return x1, x2, master
+
+    def _update_master(self, x, master):
+        att_map = self._derive_att_map_master(x, master)
+        master = self._project_master(x, master, att_map)
+
+        return master
+
+    def _pairwise_mul_nodes(self, x):
+        '''
+        Calculates pairwise multiplication of nodes.
+        - for attention map
+        x           :(#bs, #node, #dim)
+        out_shape   :(#bs, #node, #node, #dim)
+        '''
+
+        nb_nodes = x.size(1)
+        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
+        x_mirror = x.transpose(1, 2)
+
+        return x * x_mirror
+
+    def _derive_att_map_master(self, x, master):
+        '''
+        x           :(#bs, #node, #dim)
+        out_shape   :(#bs, #node, #node, 1)
+        '''
+        att_map = x * master
+        att_map = torch.tanh(self.att_projM(att_map))
+
+        att_map = torch.matmul(att_map, self.att_weightM)
+
+        # apply temperature
+        att_map = att_map / self.temp
+
+        att_map = F.softmax(att_map, dim=-2)
+
+        return att_map
+
+    def _derive_att_map(self, x, num_type1, num_type2):
+        '''
+        x           :(#bs, #node, #dim)
+        out_shape   :(#bs, #node, #node, 1)
+        '''
+        att_map = self._pairwise_mul_nodes(x)
+        # size: (#bs, #node, #node, #dim_out)
+        att_map = torch.tanh(self.att_proj(att_map))
+        # size: (#bs, #node, #node, 1)
+
+        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
+
+        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
+            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
+        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
+            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
+        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
+            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
+        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
+            att_map[:, num_type1:, :num_type1, :], self.att_weight12)
+
+        att_map = att_board
+
+        # apply temperature
+        att_map = att_map / self.temp
+
+        att_map = F.softmax(att_map, dim=-2)
+
+        return att_map
+
+    def _project(self, x, att_map):
+        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
+        x2 = self.proj_without_att(x)
+
+        return x1 + x2
+
+    def _project_master(self, x, master, att_map):
+        x1 = self.proj_with_attM(torch.matmul(
+            att_map.squeeze(-1).unsqueeze(1), x))
+        x2 = self.proj_without_attM(master)
+
+        return x1 + x2
+
+    def _apply_BN(self, x):
+        org_size = x.size()
+        x = x.view(-1, org_size[-1])
+        x = self.bn(x)
+        x = x.view(org_size)
+
+        return x
+
+    def _init_new_params(self, *size):
+        out = nn.Parameter(torch.FloatTensor(*size))
+        nn.init.xavier_normal_(out)
+        return out
+
+
+class GraphPool(nn.Module):
+    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
+        super().__init__()
+        self.k = k
+        self.sigmoid = nn.Sigmoid()
+        self.proj = nn.Linear(in_dim, 1)
+        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
+        self.in_dim = in_dim
+
+    def forward(self, h):
+        Z = self.drop(h)
+        weights = self.proj(Z)
+        scores = self.sigmoid(weights)
+        new_h = self.top_k_graph(scores, h, self.k)
+
+        return new_h
+
+    def top_k_graph(self, scores, h, k):
+        """
+        args
+        =====
+        scores: attention-based weights (#bs, #node, 1)
+        h: graph data (#bs, #node, #dim)
+        k: ratio of remaining nodes, (float)
+
+        returns
+        =====
+        h: graph pool applied data (#bs, #node', #dim)
+        """
+        _, n_nodes, n_feat = h.size()
+        n_nodes = max(int(n_nodes * k), 1)
+        _, idx = torch.topk(scores, n_nodes, dim=1)
+        idx = idx.expand(-1, -1, n_feat)
+
+        h = h * scores
+        h = torch.gather(h, 1, idx)
+
+        return h
+
+
+class Residual_block(nn.Module):
+    def __init__(self, nb_filts, first=False):
+        super().__init__()
+        self.first = first
+
+        if not self.first:
+            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
+        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
+                               out_channels=nb_filts[1],
+                               kernel_size=(2, 3),
+                               padding=(1, 1),
+                               stride=1)
+        self.selu = nn.SELU(inplace=True)
+
+        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
+        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
+                               out_channels=nb_filts[1],
+                               kernel_size=(2, 3),
+                               padding=(0, 1),
+                               stride=1)
+
+        if nb_filts[0] != nb_filts[1]:
+            self.downsample = True
+            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
+                                             out_channels=nb_filts[1],
+                                             padding=(0, 1),
+                                             kernel_size=(1, 3),
+                                             stride=1)
+        else:
+            self.downsample = False
+
+    def forward(self, x):
+        identity = x
+        if not self.first:
+            out = self.bn1(x)
+            out = self.selu(out)
+        else:
+            out = x
+
+        #print('out',out.shape)
+        out = self.conv1(out)  # conv1 consumes the pre-activated tensor
+
+        #print('aft conv1 out',out.shape)
+        out = self.bn2(out)
+        out = self.selu(out)
+        # print('out',out.shape)
+        out = self.conv2(out)
+        #print('conv2 out',out.shape)
+
+        if self.downsample:
+            identity = self.conv_downsample(identity)
+
+        out += identity
+        #out = self.mp(out)
+        return out
+
+
+class AModel(nn.Module):
+    def __init__(self, args, device):
+        super().__init__()
+        self.device = device
+
+        # AASIST parameters
+        filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
+        gat_dims = [64, 32]
+        pool_ratios = [0.5, 0.5, 0.5, 0.5]
+        temperatures = [2.0, 2.0, 100.0, 100.0]
+
+        ####
+        # create the pre-trained SSL front end (conformer_shaw)
+        ####
+        self.ssl_model = SSLModel(self.device)
+        self.LL = nn.Linear(self.ssl_model.out_dim, 128)
+
+        self.first_bn = nn.BatchNorm2d(num_features=1)
+        self.first_bn1 = nn.BatchNorm2d(num_features=64)
+        self.drop = nn.Dropout(0.5, inplace=True)
+        self.drop_way = nn.Dropout(0.2, inplace=True)
+        self.selu = nn.SELU(inplace=True)
+
+        # RawNet2 encoder
+        self.encoder = nn.Sequential(
+            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
+            nn.Sequential(Residual_block(nb_filts=filts[2])),
+            nn.Sequential(Residual_block(nb_filts=filts[3])),
+            nn.Sequential(Residual_block(nb_filts=filts[4])),
+            nn.Sequential(Residual_block(nb_filts=filts[4])),
+            nn.Sequential(Residual_block(nb_filts=filts[4])))
+
+        self.attention = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=(1, 1)),
+            nn.SELU(inplace=True),
+            nn.BatchNorm2d(128),
+            nn.Conv2d(128, 64, kernel_size=(1, 1)),
+        )
+        # position encoding
+        self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1]))
+
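+        # learnable master nodes, one per HS-GAL inference branch in forward()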
+        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
+        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
+
+        # Graph module
+        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
+                                               gat_dims[0],
+                                               temperature=temperatures[0])
+        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
+                                               gat_dims[0],
+                                               temperature=temperatures[1])
+        # HS-GAL layer
+        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
+            gat_dims[0], gat_dims[1], temperature=temperatures[2])
+        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
+            gat_dims[1], gat_dims[1], temperature=temperatures[2])
+        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
+            gat_dims[0], gat_dims[1], temperature=temperatures[2])
+        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
+            gat_dims[1], gat_dims[1], temperature=temperatures[2])
+
+        # Graph pooling layers
+        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
+        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
+        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+
+        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+
+        self.out_layer = nn.Linear(5 * gat_dims[1], 2)
+
+    def forward(self, x):
+        # ------- pre-trained SSL front end (frozen; see SSLModel.extract_feat) ------- #
+        # x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1))
+        x_ssl_feat = self.ssl_model.extract_feat(x)
+        x = self.LL(x_ssl_feat)  # (bs, frame_number, feat_out_dim)
+
+        # post-processing on front-end features
+        x = x.transpose(1, 2)  # (bs, feat_out_dim, frame_number)
+        x = x.unsqueeze(dim=1)  # add channel
+        x = F.max_pool2d(x, (3, 3))
+        x = self.first_bn(x)
+        x = self.selu(x)
+
+        # RawNet2-based encoder
+        x = self.encoder(x)
+        x = self.first_bn1(x)
+        x = self.selu(x)
+
+        w = self.attention(x)
+
+        # ------------ SA for spectral feature ------------- #
+        w1 = F.softmax(w, dim=-1)
+        m = torch.sum(x * w1, dim=-1)
+        e_S = m.transpose(1, 2) + self.pos_S
+
+        # graph module layer
+        gat_S = self.GAT_layer_S(e_S)
+        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)
+
+        # ------------ SA for temporal feature ------------- #
+        w2 = F.softmax(w, dim=-2)
+        m1 = torch.sum(x * w2, dim=-2)
+
+        e_T = m1.transpose(1, 2)
+
+        # graph module layer
+        gat_T = self.GAT_layer_T(e_T)
+        out_T = self.pool_T(gat_T)
+
+        # learnable master nodes, expanded to the batch size
+        master1 = self.master1.expand(x.size(0), -1, -1)
+        master2 = self.master2.expand(x.size(0), -1, -1)
+
+        # inference 1
+        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
+            out_T, out_S, master=master1)
+
+        out_S1 = self.pool_hS1(out_S1)
+        out_T1 = self.pool_hT1(out_T1)
+
+        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
+            out_T1, out_S1, master=master1)
+        out_T1 = out_T1 + out_T_aug
+        out_S1 = out_S1 + out_S_aug
+        master1 = master1 + master_aug
+
+        # inference 2
+        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
+            out_T, out_S, master=master2)
+        out_S2 = self.pool_hS2(out_S2)
+        out_T2 = self.pool_hT2(out_T2)
+
+        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
+            out_T2, out_S2, master=master2)
+        out_T2 = out_T2 + out_T_aug
+        out_S2 = out_S2 + out_S_aug
+        master2 = master2 + master_aug
+
+        out_T1 = self.drop_way(out_T1)
+        out_T2 = self.drop_way(out_T2)
+        out_S1 = self.drop_way(out_S1)
+        out_S2 = self.drop_way(out_S2)
+        master1 = self.drop_way(master1)
+        master2 = self.drop_way(master2)
+
+        out_T = torch.max(out_T1, out_T2)
+        out_S = torch.max(out_S1, out_S2)
+        master = torch.max(master1, master2)
+
+        # Readout operation
+        T_max, _ = torch.max(torch.abs(out_T), dim=1)
+        T_avg = torch.mean(out_T, dim=1)
+
+        S_max, _ = torch.max(torch.abs(out_S), dim=1)
+        S_avg = torch.mean(out_S, dim=1)
+
+        emb = last_hidden = torch.cat(
+            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
+
+        last_hidden = self.drop(last_hidden)
+        output = self.out_layer(last_hidden)
+
+        return emb, output
+
+
+if __name__ == '__main__':
+    # smoke test: the conformer_shaw front end consumes 80-dim fbank
+    # sequences, so decode the wav through the fairseq2 pipeline first
+    device = torch.device("cuda")
+    audio_file = "/datac/longnv/audio_samples/ADD2023_T2_T_00000000.wav"
+    audio_decoder = AudioDecoder(dtype=torch.float32, device=device)
+    fbank_converter = WaveformToFbankConverter(
+        num_mel_bins=80, waveform_scale=2**15, channel_last=True,
+        standardize=True, device=device, dtype=torch.float32)
+    with Path(audio_file).open("rb") as fb:
+        decoded_audio = audio_decoder(MemoryBlock(fb.read()))
+    src = Collater(pad_value=1)(fbank_converter(decoded_audio))["fbank"]
+    seqs, _ = get_seqs_and_padding_mask(src)  # (1, n_frames, 80)
+
+    model = AModel(None, device).to(device)
+    emb, out = model(seqs)
+    print(emb.shape)
+    print(out.shape)
diff --git a/oc_training.py b/oc_training.py
index 37562c4..bbceeea 100644
--- a/oc_training.py
+++ b/oc_training.py
@@ -14,16 +14,26 @@ from torch.utils.data import DataLoader, random_split
 from sklearn.utils.class_weight import compute_class_weight
-from models.lcnn import *
-from models.senet import *
-from models.xlsr import *
-from models.sslassist import *
+
+import fairseq2
+from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+from fairseq2.memory import MemoryBlock
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+from fairseq2.data import Collater
+from pathlib import Path
+
+
+# from models.lcnn import *
+# from models.senet import *
+# from models.xlsr import *
+# from models.sslassist import *
+from models.w2vassist import *
 from losses.custom_loss import compactness_loss, descriptiveness_loss, euclidean_distance_loss
 import torch.nn.functional as F
-from torchattacks import PGD
+# from torchattacks import PGD
 from torch.utils.data import Dataset
 from data_utils_SSL import process_Rawboost_feature
@@ -195,6 +205,30 @@ def _get_files(self, idx):
             'spoof1': spoof_files[0],  # The first spoof file
         }
+
+    def _preprocess_wav(self, audio_wav_path):
+        # fairseq2 pipeline: decode -> 80-bin fbank -> collate;
+        # returns a (1, n_frames, 80) tensor
+        dtype = torch.float32
+        device = torch.device("cuda")
+        audio_decoder = AudioDecoder(dtype=dtype, device=device)
+        fbank_converter = WaveformToFbankConverter(
+            num_mel_bins=80,
+            waveform_scale=2**15,
+            channel_last=True,
+            standardize=True,
+            device=device,
+            dtype=dtype,
+        )
+        collater = Collater(pad_value=1)
+
+        with Path(audio_wav_path).open("rb") as fb:
+            block = MemoryBlock(fb.read())
+        # print("Decoding audio...")
+        decoded_audio = audio_decoder(block)
+        src = collater(fbank_converter(decoded_audio))["fbank"]
+        seqs, padding_mask = get_seqs_and_padding_mask(src)
+        return seqs
+
     def __len__(self):
         return self._length
@@ -216,10 +250,11 @@ def __getitem__(self, idx):
             # file_path = os.path.join(self.dataset_dir, audio_file + ".flac")
             file_path = os.path.join(self.dataset_dir, audio_file + ".wav")
-            feature, sr = librosa.load(file_path, sr=None)
+            # feature, sr = librosa.load(file_path, sr=None)
+            feature = self._preprocess_wav(file_path)
             # rawboost augmentation, algo=4 is the series of 1, 2, 3
             # feature = process_Rawboost_feature(feature, sr, self.args, 5)
-            max_length = max(max_length, feature.shape[0])
+            max_length = max(max_length, feature.shape[1])
             # Convert label "spoof" = 1 and "bonafide" = 0
             label = 1 if key.startswith("spoof") else 0
@@ -233,10 +268,12 @@ def __getitem__(self, idx):
             vocoded_files = self._get_vocoded_files(file_assignments['bona1'])
             for vocoded_file in vocoded_files:
                 file_path = os.path.join(self._vocoded_dir, vocoded_file + ".wav")
-                feature, sr = librosa.load(file_path, sr=None)
+                # feature, sr = librosa.load(file_path, sr=None)
+                feature = self._preprocess_wav(file_path)
                 # rawboost augmentation, algo=4 is the series of 1, 2, 3
                 # feature = process_Rawboost_feature(feature, sr, self.args, 5)
-                max_length = max(max_length, feature.shape[0])
+
+                max_length = max(max_length, feature.shape[1])
                 label = 1
                 features.append(feature)
                 labels.append(label)
@@ -244,16 +281,11 @@ def __getitem__(self, idx):
         # Pad the features to have the same length
         features_padded = []
         for feature in features:
-            # You might want to specify the type of padding, e.g., zero padding
-            feature_padded = np.pad(feature, (0, max_length - len(feature)), mode='constant')
+            # F.pad pairs run from the last dim backwards: (0, 0) leaves the 80
+            # mel bins untouched, (0, pad) zero-pads frames at the end of the
+            # time axis, matching the original np.pad behavior
+            feature_padded = F.pad(feature, (0, 0, 0, max_length - feature.size(1)))
             features_padded.append(feature_padded)
-        # Convert the list of features and labels to tensors
-        features = np.array(features_padded)
-        labels = np.array(labels)
-        feature_tensors = torch.tensor(features, dtype=torch.float32)
-        label_tensors = torch.tensor(labels, dtype=torch.int64)
-        return feature_tensors, label_tensors
+
+        return torch.stack(features_padded), torch.tensor(labels, dtype=torch.int64)

     def collate_fn(self, batch):
         """pad the time series 1D"""
@@ -336,7 +368,7 @@ def collate_fn(self, batch):

     # WandB – Initialize a new run
-    wandb.init(project="oc_classifier", entity="longnv")
+    wandb.init(project="oc_classifier-w2v", entity="longnv")

     # Number of epochs
     num_epochs = 100
@@ -362,16 +394,10 @@ def collate_fn(self, batch):
         for i, data in enumerate(train_dataloader, 0):
             inputs, labels = data[0].to(device), data[1].to(device) # torch.Size([1, 8, 71648]), torch.Size([1, 8])
-            # print(f"inputs.shape = {inputs.shape}, labels.shape = {labels.shape}")
             inputs = inputs.squeeze(0) # torch.Size([12, 81204])
+            inputs = inputs.squeeze(1)  # (n_files, 1, n_frames, 80) -> (n_files, n_frames, 80)
             optimizer.zero_grad()
-            # Forward pass
-            # outputs_ssl = ssl(inputs) # torch.Size([12, 191, 1024])
-            # outputs_ssl = outputs_ssl.unsqueeze(1) # torch.Size([12, 1, 191, 1024])
-
-            # outputs_senet34 = senet34(outputs_ssl) # torch.Size([12, 128])
-            # outputs_lcnn = lcnn(outputs_ssl) # torch.Size([8, 2])
             outputs_senet34 = outputs_aasist = aasist(inputs) # torch.Size([12, 128])
             com = outputs_senet34[0]
             des = outputs_senet34[1]
@@ -398,5 +424,5 @@ def collate_fn(self, batch):
             print("Saving the models...")
             # torch.save(ssl.module.state_dict(), f"ssl_vocoded_{epoch}.pt")
             # torch.save(senet34.module.state_dict(), f"senet34_vocoded_{epoch}.pt")
-            torch.save(aasist.module.state_dict(), f"aasist_vocoded_{epoch}.pt")
+            torch.save(aasist.module.state_dict(), f"w2v_vocoded_{epoch}.pt")
             # torch.save(lcnn.module.state_dict(), f"lcnn_{epoch}.pt")
\ No newline at end of file
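
Note: the training loop saves aasist.module.state_dict(), i.e. the weights of the DataParallel-unwrapped AModel, so evaluation code can load them into a bare AModel. A minimal sketch under stated assumptions (a hypothetical epoch-0 checkpoint w2v_vocoded_0.pt, a CUDA device, and a random stand-in for the (1, n_frames, 80) fbank batches that _preprocess_wav produces):

import torch

from models.w2vassist import AModel

device = torch.device("cuda")
model = AModel(None, device).to(device)

# keys were saved from aasist.module, so they match a bare (non-DataParallel) AModel
state = torch.load("w2v_vocoded_0.pt", map_location=device)  # hypothetical epoch-0 checkpoint
model.load_state_dict(state)
model.eval()

# dummy batch standing in for the standardized fbank output of _preprocess_wav
seqs = torch.randn(1, 300, 80, device=device)
with torch.no_grad():
    emb, logits = model(seqs)
print(emb.shape, logits.shape)  # expected: (1, 160) embedding and (1, 2) logits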