Skip to content

Commit

Permalink
finish compactness+descriptiveness branches
Browse files Browse the repository at this point in the history
  • Loading branch information
Long Nguyen-Vu committed Dec 12, 2023
1 parent a1870cf commit 86fa3eb
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 278 deletions.
4 changes: 3 additions & 1 deletion models/sslassist.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ def __init__(self, args, device):
self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

self.out_layer = nn.Linear(5 * gat_dims[1], 2)
self.emb_layer = nn.Linear(160, 100)

def forward(self, x):
#-------pre-trained Wav2vec model fine tunning ------------------------##
Expand Down Expand Up @@ -588,10 +589,11 @@ def forward(self, x):
S_max, _ = torch.max(torch.abs(out_S), dim=1)
S_avg = torch.mean(out_S, dim=1)

emb = last_hidden = torch.cat(
last_hidden = torch.cat(
[T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)

last_hidden = self.drop(last_hidden)
emb = self.emb_layer(last_hidden)
output = self.out_layer(last_hidden)

return emb, output
Expand Down
4 changes: 2 additions & 2 deletions oc_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,8 @@ def collate_fn(self, batch):
des = outputs_senet34[1]
# Calculate the losses
# c_loss = euclidean_distance_loss(com)
c_loss = 0.0*compactness_loss(com)
d_loss = 1.0*descriptiveness_loss(des, labels.squeeze(0)) # because labels.shape = torch.Size([1, 8])
c_loss = 0.1*compactness_loss(com)
d_loss = 0.9*descriptiveness_loss(des, labels.squeeze(0)) # because labels.shape = torch.Size([1, 8])
loss = c_loss + d_loss

loss.backward()
Expand Down
146 changes: 0 additions & 146 deletions test_dataloader_v2.py

This file was deleted.

61 changes: 0 additions & 61 deletions test_model_merge.py

This file was deleted.

68 changes: 0 additions & 68 deletions test_sampler.py

This file was deleted.

0 comments on commit 86fa3eb

Please sign in to comment.