Added pytorch and pybuda mnist training scripts. (#36)
vladimirjovanovicTT authored Aug 9, 2024
1 parent 48db3d3 commit 986a7b2
Showing 3 changed files with 211 additions and 0 deletions.
127 changes: 127 additions & 0 deletions pybuda/test/mlir/mnist/training/mnist_linear_pybuda.py
@@ -0,0 +1,127 @@
import torch
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

import pybuda
from pybuda import (
    CPUDevice,
    PyTorchModule,
)
from utils import (
    MNISTLinear,
    Identity,
    load_tb_writer,
    load_dataset,
)
from pybuda.config import _get_global_compiler_config

class FeedForward(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(FeedForward, self).__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

def train(loss_on_cpu=True):
    torch.manual_seed(777)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda x: x.view(-1)),
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    writer = SummaryWriter()

    num_epochs = 2
    input_size = 784
    hidden_size = 256
    output_size = 10
    batch_size = 3
    learning_rate = 0.001
    sequential = True

    # Wrap the PyTorch model for PyBuda and place it on a Tenstorrent device.
    framework_model = FeedForward(input_size, hidden_size, output_size)
    tt_model = pybuda.PyTorchModule(f"mnist_linear_{batch_size}", framework_model)
    tt_optimizer = pybuda.optimizers.SGD(
        learning_rate=learning_rate, device_params=True
    )
    tt0 = pybuda.TTDevice("tt0", module=tt_model, optimizer=tt_optimizer)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Dataset sample input/target, passed to initialize_pipeline for compilation.
    first_sample = test_loader.dataset[0]
    sample_input = (first_sample[0].repeat(1, batch_size, 1),)
    sample_target = (
        torch.nn.functional.one_hot(torch.tensor(first_sample[1]), num_classes=output_size)
        .float()
        .repeat(1, batch_size, 1)
    )

    # The loss module can run on the host CPU (behind an identity module) or on the device.
    if loss_on_cpu:
        cpu0 = CPUDevice("cpu0", module=PyTorchModule("identity", Identity()))
        cpu0.place_loss_module(pybuda.PyTorchModule(f"loss_{batch_size}", torch.nn.CrossEntropyLoss()))
    else:
        tt_loss = pybuda.PyTorchModule(f"loss_{batch_size}", torch.nn.CrossEntropyLoss())
        tt0.place_loss_module(tt_loss)

    compiler_cfg = _get_global_compiler_config()
    compiler_cfg.enable_auto_fusing = False

    if not loss_on_cpu:
        sample_target = (sample_target,)

    # Compile the training pipeline against the sample input/target shapes.
    checkpoint_queue = pybuda.initialize_pipeline(
        training=True,
        sample_inputs=sample_input,
        sample_targets=sample_target,
        _sequential=sequential,
    )

    # Unused in this script; the PyTorch variant below uses these to track the best epoch.
    best_accuracy = 0.0
    best_checkpoint = None

    for epoch in range(num_epochs):
        for batch_idx, (images, labels) in enumerate(train_loader):
            images = (images.unsqueeze(0),)
            tt0.push_to_inputs(images)

            targets = (
                torch.nn.functional.one_hot(labels, num_classes=output_size)
                .float()
                .unsqueeze(0)
            )
            if loss_on_cpu:
                cpu0.push_to_target_inputs(targets)
            else:
                tt0.push_to_target_inputs(targets)

            # One training step: forward, backward, and optimizer update.
            pybuda.run_forward(input_count=1, _sequential=sequential)
            pybuda.run_backward(input_count=1, zero_grad=True, _sequential=sequential)
            pybuda.run_optimizer(checkpoint=True, _sequential=sequential)

    loss_q = pybuda.run.get_loss_queue()

    step = 0
    loss = loss_q.get()[0]
    print(loss)
    # while not loss_q.empty():
    #     if loss_on_cpu:
    #         writer.add_scalar("Loss/PyBuda/overfit", loss_q.get()[0], step)
    #     else:
    #         writer.add_scalar("Loss/PyBuda/overfit", loss_q.get()[0].value()[0], step)
    #     step += 1

    writer.close()


if __name__ == "__main__":
    train()
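
Note: the script above imports MNISTLinear, Identity, load_tb_writer, and load_dataset from a local utils module that is not included in this diff. A minimal sketch of what that module plausibly contains, inferred from how the names are used above; every signature here is an assumption, not the committed code:

# utils.py -- hypothetical reconstruction; the real module is not shown in this diff.
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms


class MNISTLinear(torch.nn.Module):
    """Simple linear MNIST classifier (assumed shape; unused in the script above)."""

    def __init__(self, input_size=784, hidden_size=256, output_size=10):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


class Identity(torch.nn.Module):
    """Pass-through module, used above to host the loss on a CPUDevice."""

    def forward(self, x):
        return x


def load_tb_writer(log_dir=None):
    """Return a TensorBoard SummaryWriter (assumed signature)."""
    return SummaryWriter(log_dir=log_dir)


def load_dataset():
    """Load normalized, flattened MNIST train/test splits (assumed signature)."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda x: x.view(-1)),
    ])
    train_set = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_set = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
    return train_set, test_set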
84 changes: 84 additions & 0 deletions pybuda/test/mlir/mnist/training/mnist_linear_pytorch.py
@@ -0,0 +1,84 @@
import copy

import torch
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

class FeedForward(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(FeedForward, self).__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

def train():
    torch.manual_seed(777)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda x: x.view(-1)),
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    writer = SummaryWriter()

    num_epochs = 10
    input_size = 784
    hidden_size = 256
    output_size = 10
    model = FeedForward(input_size, hidden_size, output_size)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

    best_accuracy = 0.0
    best_checkpoint = None

    for epoch in range(num_epochs):
        for batch_idx, (images, labels) in enumerate(train_loader):
            outputs = model(images)
            loss = torch.nn.CrossEntropyLoss()(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
                writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + batch_idx)

        # Evaluate on the test set after each epoch.
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for images, labels in test_loader:
                outputs = model(images)
                _, predicted = torch.max(outputs, dim=1)
                total_samples += labels.size(0)
                total_correct += (predicted == labels).sum().item()

        accuracy = 100.0 * total_correct / total_samples
        print(f'Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {accuracy:.2f}%')

        # Snapshot the best-performing epoch. deepcopy is needed: state_dict()
        # returns references to the live parameter tensors, so without it the
        # "checkpoint" would silently track the current weights.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_checkpoint = {
                'epoch': epoch,
                'model_state_dict': copy.deepcopy(model.state_dict()),
                'optimizer_state_dict': copy.deepcopy(optimizer.state_dict()),
                'accuracy': accuracy,
            }

    # Revert to the checkpoint with the highest validation accuracy.
    if best_checkpoint is not None:
        model.load_state_dict(best_checkpoint['model_state_dict'])
        optimizer.load_state_dict(best_checkpoint['optimizer_state_dict'])
        print(f'Reverted to checkpoint with highest validation accuracy: {best_checkpoint["accuracy"]:.2f}%')

    writer.close()


if __name__ == "__main__":
    train()
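
The checkpoint-revert logic above keeps the best snapshot only in memory. A small sketch, not part of the original script, of how the same dict could be persisted with torch.save and restored later (file name is arbitrary):

# Hypothetical addition: persist the best checkpoint to disk.
torch.save(best_checkpoint, "mnist_linear_best.pt")

# Later, e.g. in a separate process, after rebuilding model and optimizer:
checkpoint = torch.load("mnist_linear_best.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
print(f"Loaded checkpoint with accuracy {checkpoint['accuracy']:.2f}%")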
File renamed without changes.
