Skip to content

Commit

Permalink
Added FDM loss for time into loss function/ fixed errors in FNO module
Browse files Browse the repository at this point in the history
  • Loading branch information
neelsankaran committed Aug 18, 2023
1 parent 9a05bc2 commit 23d4941
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/models/PINO_util/fno_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def forward(self, x, index=0, output_shape = None, default_render = None):
if self.norm is not None:
x = self.norm[self.n_norms*index](x)
if default_render is not None:
default_render = self.norm[self.n_norms*index](x)
default_render = self.norm[self.n_norms*index](default_render)

x_skip_fno = self.fno_skips[index](x)
if default_render is not None:
Expand Down
9 changes: 7 additions & 2 deletions src/op_lib/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
import torch
import torch.nn.functional as F
from neuralop.layers.resample import resample

class LpLoss(object):
def __init__(self, d=1, p=2, L=2*math.pi, reduce_dims=0, reductions='sum', add_PDE_LOSS = False):
Expand Down Expand Up @@ -182,7 +183,7 @@ def rel(self, x, y, h=None):
def __call__(self, x, y, h=None):
return self.rel(x, y, h=h)

def temp_stokes_loss(T, u, v):
def temp_stokes_loss2D(T, u, v, T_prev, resolution_scaling, dt):
batchsize = T.size(0)
nx = T.size(2)
ny = T.size(3)
Expand All @@ -191,6 +192,8 @@ def temp_stokes_loss(T, u, v):
T = T.reshape(batchsize, nx, ny)
u = u.reshape(batchsize, nx, ny)
v = v.reshape(batchsize, nx, ny)
if T_prev.size(-2) != T.size(-2) or T_prev.size(-1) != T.size(-1):
T_prev = resample(T_prev, resolution_scaling, [-2, -1], output_shape=T.shape)

T_h = torch.fft.fft2(T, dim=[-2, -1])
u_h = torch.fft.fft2(u, dim=[-2, -1])
Expand Down Expand Up @@ -220,5 +223,7 @@ def temp_stokes_loss(T, u, v):
gradTdotu = torch.fft.irfft2(gradTdotu_h[:, :, :k_maxy + 1], dim=[-2, -1])
Tlap = torch.fft.irfft2(Tlap_h[:, :, :k_maxy+1], dim=[-2,-1])

PDE_LOSS = torch.sum(torch.square(gradTdotu - Tlap))
Tdt = (T-T_prev)/dt

PDE_LOSS = torch.sum(torch.square(Tdt + gradTdotu - Tlap))
return PDE_LOSS
2 changes: 1 addition & 1 deletion src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import matplotlib.pyplot as plt
import numpy as np
from neuralop.models import UNO
from ..models.PINO_util import FNO
from ..models.PINO_util.fno import FNO
from pathlib import Path
import os
import time
Expand Down

0 comments on commit 23d4941

Please sign in to comment.