Skip to content

General Model

Kevin Lane edited this page Jan 13, 2021 · 2 revisions
class TimeModel(nn.Module):
    """Base class bundling a generic PyTorch train/evaluate loop.

    Subclasses implement :meth:`forward`. :meth:`fit` then trains the model
    with a caller-supplied optimizer and (optional) learning-rate scheduler.

    The loss callable is invoked as ``loss_fn(target, prediction)``. Pass it
    explicitly via the ``loss_fn`` keyword. When omitted, the module-level
    name ``loss_fun`` is used, and it must be defined elsewhere.

    NOTE(review): ``torch.nn`` loss modules expect ``(input, target)``. The
    ``(target, prediction)`` order used here is harmless for symmetric losses
    such as MSE/L1 but not for e.g. cross-entropy. Confirm the intended order.
    """

    def forward(self, x):
        """Compute predictions for a batch ``x``; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward(x)"
        )

    @staticmethod
    def _resolve_loss(loss_fn):
        """Return ``loss_fn``, or the module-level ``loss_fun`` when it is None."""
        # The global is only looked up when no loss is passed explicitly. This
        # keeps the original behaviour without making the global mandatory.
        return loss_fn if loss_fn is not None else loss_fun

    def _eval(self, loader, device, loss_fn=None):
        """Evaluate the model on ``loader`` without tracking gradients.

        Args:
            loader: iterable of ``(x, y)`` batches that also exposes a
                ``.dataset`` attribute (e.g. a ``torch.utils.data.DataLoader``).
            device: device each batch is moved to before the forward pass.
            loss_fn: loss callable ``loss_fn(y, pred)``. Defaults to the
                module-level ``loss_fun``.

        Returns:
            The batch losses summed over the whole loader and divided by
            ``len(loader.dataset)``, as a Python float.
            NOTE(review): this is a per-sample mean only if ``loss_fn`` sums
            over the batch (e.g. ``reduction="sum"``). With PyTorch's default
            ``reduction="mean"`` it is smaller by roughly the batch size.

        Raises:
            ValueError: if ``loader.dataset`` is empty.
        """
        fn = self._resolve_loss(loss_fn)
        n_samples = len(loader.dataset)
        if n_samples == 0:
            raise ValueError("cannot evaluate on an empty dataset")

        self.eval()
        total = 0.0
        with torch.no_grad():
            # Iterate the loader itself: calling it (``loader()``) is a
            # TypeError for DataLoader objects.
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                total = total + fn(y, self(x))
        # Convert once at the end rather than syncing the device per batch.
        return float(total) / n_samples

    def _train(self, loader, device, optim, loss_fn=None):
        """Run one training epoch over ``loader``, stepping ``optim`` per batch.

        Args:
            loader: iterable of ``(x, y)`` batches.
            device: device each batch is moved to before the forward pass.
            optim: optimizer over this model's parameters.
            loss_fn: loss callable ``loss_fn(y, pred)``. Defaults to the
                module-level ``loss_fun``.
        """
        fn = self._resolve_loss(loss_fn)
        self.train()
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optim.zero_grad()
            loss = fn(y, self(x))
            loss.backward()
            optim.step()

    def fit(self, train_loader, test_loader, device, optim, scheduler,
            n_epochs, loss_fn=None):
        """Train for ``n_epochs`` epochs and return the per-epoch loss history.

        Args:
            train_loader: loader used for the weight updates and train loss.
            test_loader: loader used only to compute the test loss.
            device: device batches are moved to.
            optim: optimizer over this model's parameters.
            scheduler: learning-rate scheduler stepped once per epoch, or
                ``None`` to skip scheduling.
            n_epochs: number of training epochs.
            loss_fn: loss callable ``loss_fn(y, pred)``. Defaults to the
                module-level ``loss_fun``.

        Returns:
            ``{"train_loss": [...], "test_loss": [...]}``. Each list holds one
            float per epoch, as computed by :meth:`_eval` after that epoch's
            training pass.
        """
        history = {"train_loss": [], "test_loss": []}
        for _ in range(n_epochs):
            self._train(train_loader, device, optim, loss_fn=loss_fn)
            history["train_loss"].append(
                self._eval(train_loader, device, loss_fn=loss_fn)
            )
            history["test_loss"].append(
                self._eval(test_loader, device, loss_fn=loss_fn)
            )
            if scheduler is not None:
                scheduler.step()
        return history

    def predict(self, x):
        """Return ``self(x)``, computed in eval mode without tracking gradients.

        ``x`` is passed to :meth:`forward` unchanged. Move it to the model's
        device beforehand if needed.
        """
        self.eval()
        with torch.no_grad():
            return self(x)
Clone this wiki locally