-
Notifications
You must be signed in to change notification settings - Fork 3
/
losses.py
28 lines (23 loc) · 898 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Differentiable FDN for Colorless Reverberation
# custom loss functions
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class mse_loss(nn.Module):
    """Squared-magnitude error on each output channel and on their sum.

    The loss is the sum of two terms:

    * the mean squared error between ``abs(y_pred[..., i])`` and
      ``abs(y_true)``, averaged over the ``N`` channels (last dimension of
      ``y_pred``);
    * the mean squared error between ``abs(y_pred.sum(-1))`` and
      ``abs(y_true)``, i.e. on the summed system output.

    ``torch.abs`` is applied to both prediction and target, so complex
    inputs are compared by magnitude.
    """

    def forward(self, y_pred, y_true):
        """Return the channel-averaged loss plus the system-output loss.

        Args:
            y_pred: tensor whose last dimension indexes the ``N`` output
                channels. NOTE(review): presumably shape (batch, N) — confirm
                with the caller.
            y_true: target broadcastable against ``y_pred[..., i]`` (e.g.
                shape (batch,) for a (batch, N) ``y_pred``).

        Returns:
            0-dim tensor: ``mean_i MSE(|y_pred_i|, |y_true|)
            + MSE(|sum_i y_pred_i|, |y_true|)``.
        """
        n_channels = y_pred.size(dim=-1)
        target_mag = torch.abs(y_true)  # loop-invariant, compute once
        loss = 0.0
        # loss on channels' output.
        # ``** 2`` replaces ``torch.pow(..., 2*torch.ones(B))``: that exponent
        # tensor was allocated on the CPU every iteration and caused a
        # device mismatch when y_pred lives on the GPU.
        for i in range(n_channels):
            # index the last dim, consistent with size(-1) and sum(dim=-1)
            channel_err = torch.abs(y_pred[..., i]) - target_mag
            loss = loss + torch.mean(channel_err ** 2)
        # loss on system's output (sum over channels)
        y_pred_sum = torch.sum(y_pred, dim=-1)
        system_loss = torch.mean((torch.abs(y_pred_sum) - target_mag) ** 2)
        return loss / n_channels + system_loss
class sparsity_loss(nn.Module):
    """Normalized L1-norm loss that penalizes sparsity of a square matrix.

    For an ``N x N`` matrix ``A`` this returns
    ``-(sum(|A|) - N) / (N * (sqrt(N) - 1))``.

    If ``A`` is orthogonal (NOTE(review): presumably the FDN feedback matrix
    — confirm with the caller), every row has unit L2 norm, so ``sum(|A|)``
    lies in ``[N, N*sqrt(N)]`` and the loss lies in ``[-1, 0]``. It is 0 for
    a permutation-like (sparsest) matrix and -1 for a maximally dense one
    (e.g. a normalized Hadamard matrix). Minimizing it therefore pushes
    ``A`` towards density.
    """

    def forward(self, A):
        """Return the normalized sparsity loss of ``A`` as a 0-dim tensor.

        Raises:
            ValueError: if the last dimension of ``A`` is smaller than 2, for
                which the normalizer ``N*(sqrt(N)-1)`` is zero.
        """
        N = A.shape[-1]
        if N < 2:
            # sqrt(1) - 1 == 0 (and N == 0 also zeroes it): the expression
            # would evaluate to NaN and silently poison the total loss.
            raise ValueError(f"sparsity_loss requires a matrix of size N >= 2, got N={N}")
        return -(torch.sum(torch.abs(A)) - N) / (N * (np.sqrt(N) - 1))