diff --git a/pybuda/test/mlir/mnist/training/mnist_linear_pybuda.py b/pybuda/test/mlir/mnist/training/mnist_linear_pybuda.py
new file mode 100644
index 000000000..eb0364aae
--- /dev/null
+++ b/pybuda/test/mlir/mnist/training/mnist_linear_pybuda.py
@@ -0,0 +1,128 @@
+import torch
+from torchvision import datasets, transforms
+from torch.utils.tensorboard import SummaryWriter
+
+import pybuda
+from pybuda import (
+    CPUDevice,
+    PyTorchModule,
+)
+from utils import (
+    MNISTLinear,
+    Identity,
+    load_tb_writer,
+    load_dataset,
+)
+from pybuda.config import _get_global_compiler_config
+
+class FeedForward(torch.nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super(FeedForward, self).__init__()
+        self.fc1 = torch.nn.Linear(input_size, hidden_size)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        return x
+
+def train(loss_on_cpu=True):
+    torch.manual_seed(777)
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,)),
+        transforms.Lambda(lambda x: x.view(-1))
+    ])
+    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
+    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
+
+    writer = SummaryWriter()
+
+    num_epochs = 2
+    input_size = 784
+    hidden_size = 256
+    output_size = 10
+    batch_size = 3
+    learning_rate = 0.001
+    sequential = True
+
+    framework_model = FeedForward(input_size, hidden_size, output_size)
+    tt_model = pybuda.PyTorchModule(f"mnist_linear_{batch_size}", framework_model)
+    tt_optimizer = pybuda.optimizers.SGD(
+        learning_rate=learning_rate, device_params=True
+    )
+    tt0 = pybuda.TTDevice("tt0", module=tt_model, optimizer=tt_optimizer)
+
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+    # Dataset sample input
+    first_sample = test_loader.dataset[0]
+    sample_input = (first_sample[0].repeat(1, batch_size, 1),)
+    sample_target = (
+        torch.nn.functional.one_hot(torch.tensor(first_sample[1]), num_classes=output_size)
+        .float()
+        .repeat(1, batch_size, 1)
+    )
+
+    if loss_on_cpu:
+        cpu0 = CPUDevice("cpu0", module=PyTorchModule("identity", Identity()))
+        cpu0.place_loss_module(pybuda.PyTorchModule(f"loss_{batch_size}", torch.nn.CrossEntropyLoss()))
+    else:
+        tt_loss = pybuda.PyTorchModule(f"loss_{batch_size}", torch.nn.CrossEntropyLoss())
+        tt0.place_loss_module(tt_loss)
+
+    compiler_cfg = _get_global_compiler_config()
+    compiler_cfg.enable_auto_fusing = False
+
+    if not loss_on_cpu:
+        sample_target = (sample_target,)
+
+    checkpoint_queue = pybuda.initialize_pipeline(
+        training=True,
+        sample_inputs=sample_input,
+        sample_targets=sample_target,
+        _sequential=sequential,
+    )
+
+    best_accuracy = 0.0
+    best_checkpoint = None
+
+    for epoch in range(num_epochs):
+        for batch_idx, (images, labels) in enumerate(train_loader):
+
+            images = (images.unsqueeze(0),)
+            tt0.push_to_inputs(images)
+
+            targets = (
+                torch.nn.functional.one_hot(labels, num_classes=output_size)
+                .float()
+                .unsqueeze(0)
+            )
+            if loss_on_cpu:
+                cpu0.push_to_target_inputs(targets)
+            else:
+                tt0.push_to_target_inputs(targets)
+
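+            # One full training step on device: forward pass, backward pass, then optimizer update.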
+            pybuda.run_forward(input_count=1, _sequential=sequential)
+            pybuda.run_backward(input_count=1, zero_grad=True, _sequential=sequential)
+            pybuda.run_optimizer(checkpoint=True, _sequential=sequential)
+
+    loss_q = pybuda.run.get_loss_queue()
+
+    step = 0
+    loss = loss_q.get()[0]
+    print(loss)
+    # while not loss_q.empty():
+    #     if loss_on_cpu:
+    #         writer.add_scalar("Loss/PyBuda/overfit", loss_q.get()[0], step)
+    #     else:
+    #         writer.add_scalar("Loss/PyBuda/overfit", loss_q.get()[0].value()[0], step)
+    #     step += 1
+
+    writer.close()
+
+if __name__ == "__main__":
+    train()
diff --git a/pybuda/test/mlir/mnist/training/mnist_linear_pytorch.py b/pybuda/test/mlir/mnist/training/mnist_linear_pytorch.py
new file mode 100644
index 000000000..7958745e4
--- /dev/null
+++ b/pybuda/test/mlir/mnist/training/mnist_linear_pytorch.py
@@ -0,0 +1,85 @@
+import torch
+from torchvision import datasets, transforms
+from torch.utils.tensorboard import SummaryWriter
+
+class FeedForward(torch.nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super(FeedForward, self).__init__()
+        self.fc1 = torch.nn.Linear(input_size, hidden_size)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        return x
+
+def train():
+    torch.manual_seed(777)
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,)),
+        transforms.Lambda(lambda x: x.view(-1))
+    ])
+    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
+    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
+
+    writer = SummaryWriter()
+
+    num_epochs = 10
+    input_size = 784
+    hidden_size = 256
+    output_size = 10
+    model = FeedForward(input_size, hidden_size, output_size)
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
+    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)
+
+    best_accuracy = 0.0
+    best_checkpoint = None
+
+    for epoch in range(num_epochs):
+        for batch_idx, (images, labels) in enumerate(train_loader):
+            outputs = model(images)
+            loss = torch.nn.CrossEntropyLoss()(outputs, labels)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if (batch_idx+1) % 100 == 0:
+                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
+                writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + batch_idx)
+
+        total_correct = 0
+        total_samples = 0
+        with torch.no_grad():
+            for images, labels in test_loader:
+                outputs = model(images)
+                _, predicted = torch.max(outputs, dim=1)
+                total_samples += labels.size(0)
+                total_correct += (predicted == labels).sum().item()
+
+        accuracy = 100.0 * total_correct / total_samples
+        print(f'Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {accuracy:.2f}%')
+
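+        # Keep a checkpoint of the model/optimizer state whenever validation accuracy improves.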
+        if accuracy > best_accuracy:
+            best_accuracy = accuracy
+            best_checkpoint = {
+                'epoch': epoch,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'accuracy': accuracy
+            }
+
+    if best_checkpoint is not None:
+        model.load_state_dict(best_checkpoint['model_state_dict'])
+        optimizer.load_state_dict(best_checkpoint['optimizer_state_dict'])
+        print(f'Reverted to checkpoint with highest validation accuracy: {best_checkpoint["accuracy"]:.2f}%')
+
+    writer.close()
+
+if __name__ == "__main__":
+    train()
diff --git a/pybuda/test/mlir/mnist/test_training.py b/pybuda/test/mlir/mnist/training/test_training.py
similarity index 100%
rename from pybuda/test/mlir/mnist/test_training.py
rename to pybuda/test/mlir/mnist/training/test_training.py