Skip to content

Commit

Permalink
Update crit.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Yalan-Song authored Dec 1, 2023
1 parent bda8a6e commit 20e93e6
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions hydroDL/model/crit.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,18 @@ def forward(self, output, target):
class NSELossBatch(torch.nn.Module):
# Same as Fredrick 2019, batch NSE loss
# stdarray: the standard deviation of the runoff for all basins
def __init__(self, stdarray, eps=0.1):
def __init__(self, stdarray, eps=0.1,device = torch.cuda.current_device()):
super(NSELossBatch, self).__init__()
self.std = stdarray
self.eps = eps
self.device = device

def forward(self, output, target, igrid):
nt = target.shape[0]
stdse = np.tile(self.std[igrid].T, (nt, 1))
if torch.cuda.is_available():
stdbatch = torch.tensor(stdse, requires_grad=False).float().cuda()
else:
stdbatch = torch.tensor(stdse, requires_grad=False).float()

tdbatch = torch.tensor(stdse, requires_grad=False).float().to(device)

p0 = output[:, :, 0] # dim: Time*Gage
t0 = target[:, :, 0]
mask = t0 == t0
Expand All @@ -219,21 +219,20 @@ def forward(self, output, target, igrid):
# mask = t0 == t0
# loss = torch.mean(normRes[mask])
return loss

class NSESqrtLossBatch(torch.nn.Module):
# Same as Fredrick 2019, batch NSE loss, use RMSE and STD instead
# stdarray: the standard deviation of the runoff for all basins
def __init__(self, stdarray, eps=0.1):
def __init__(self, stdarray, eps=0.1,device = torch.cuda.current_device()):
super(NSESqrtLossBatch, self).__init__()
self.std = stdarray
self.eps = eps
self.device = device

def forward(self, output, target, igrid):
nt = target.shape[0]
stdse = np.tile(self.std[igrid], (nt, 1))
if torch.cuda.is_available():
stdbatch = torch.tensor(stdse, requires_grad=False).float().cuda()
else:
stdbatch = torch.tensor(stdse, requires_grad=False).float()
tdbatch = torch.tensor(stdse, requires_grad=False).float().to(device)
p0 = output[:, :, 0] # dim: Time*Gage
t0 = target[:, :, 0]
mask = t0 == t0
Expand Down

0 comments on commit 20e93e6

Please sign in to comment.