-
Notifications
You must be signed in to change notification settings - Fork 22
General Model
Kevin Lane edited this page on Jan 13, 2021 · 2 revisions
class TimeModel(nn.Module):
    """Base ``nn.Module`` with a generic train / evaluate / predict loop.

    Subclasses implement :meth:`forward`. The loss is invoked as
    ``loss_fn(y, pred)`` (target first, prediction second, as in the
    original template). When no ``loss_fn`` is passed to a method, the
    module-level name ``loss_fun`` is looked up at call time, which keeps
    the original behavior for code that defines that global.
    """

    def forward(self, x):
        """Compute predictions for a batch ``x``; subclasses must override."""
        raise NotImplementedError("subclasses of TimeModel must implement forward()")

    @staticmethod
    def _resolve_loss_fn(loss_fn):
        """Return ``loss_fn`` if given, else the module-level ``loss_fun``."""
        # The global is looked up lazily, so a missing ``loss_fun`` only
        # raises NameError when no explicit loss_fn was supplied.
        return loss_fn if loss_fn is not None else loss_fun

    def _eval(self, loader, device, loss_fn=None):
        """Return the per-sample average loss of the model over ``loader``.

        Args:
            loader: Iterable of ``(x, y)`` batches (e.g. a ``DataLoader``).
            device: Device the batches are moved to before the forward pass.
            loss_fn: Optional loss callable ``loss_fn(y, pred)``; defaults to
                the module-level ``loss_fun``.

        Returns:
            The average loss as a Python float.

        Raises:
            ValueError: If ``loader`` yields no samples.
        """
        loss_fn = self._resolve_loss_fn(loss_fn)
        self.eval()
        total, n_seen = 0.0, 0
        with torch.no_grad():
            # Iterate the loader itself; a DataLoader is not callable.
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = self(x)
                batch_size = len(x)  # assumes dim 0 of x is the batch dim
                # Assumes loss_fn returns the batch *mean* (PyTorch's default
                # reduction="mean"); re-weighting by batch size makes a short
                # final batch count proportionally.
                total += loss_fn(y, pred).item() * batch_size
                n_seen += batch_size
        if n_seen == 0:
            raise ValueError("loader yielded no samples")
        return total / n_seen

    def _train(self, loader, device, optim, loss_fn=None):
        """Run one optimization pass over ``loader``.

        Args:
            loader: Iterable of ``(x, y)`` batches.
            device: Device the batches are moved to.
            optim: Optimizer over this model's parameters.
            loss_fn: Optional loss callable ``loss_fn(y, pred)``; defaults to
                the module-level ``loss_fun``.
        """
        loss_fn = self._resolve_loss_fn(loss_fn)
        self.train()
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optim.zero_grad()
            loss = loss_fn(y, self(x))
            loss.backward()
            optim.step()

    def fit(self, train_loader, test_loader, device, optim, scheduler, n_epochs,
            loss_fn=None):
        """Train for ``n_epochs`` and record train/test loss after each epoch.

        Args:
            train_loader: Batches used for optimization and train-loss reporting.
            test_loader: Batches used for test-loss reporting only.
            device: Device the batches are moved to.
            optim: Optimizer over this model's parameters.
            scheduler: LR scheduler stepped once per epoch, or ``None``.
            n_epochs: Number of passes over ``train_loader``.
            loss_fn: Optional loss callable ``loss_fn(y, pred)``; defaults to
                the module-level ``loss_fun``.

        Returns:
            A list with one ``(train_loss, test_loss)`` tuple of floats per epoch.
        """
        history = []
        for _ in range(n_epochs):
            self._train(train_loader, device, optim, loss_fn=loss_fn)
            train_loss = self._eval(train_loader, device, loss_fn=loss_fn)
            test_loss = self._eval(test_loader, device, loss_fn=loss_fn)
            if scheduler is not None:
                scheduler.step()
            history.append((train_loss, test_loss))
        return history

    def predict(self, x):
        """Return the model output for ``x`` in eval mode without gradients.

        ``x`` must already be on the same device as the model's parameters.
        """
        self.eval()
        with torch.no_grad():
            return self(x)