Skip to content

Commit

Permalink
Created using Colaboratory
Browse files Browse the repository at this point in the history
  • Loading branch information
rajlm10 committed Mar 8, 2022
1 parent 96b4be2 commit a5a8417
Showing 1 changed file with 304 additions and 0 deletions.
304 changes: 304 additions & 0 deletions D2L_MLP.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "D2L_MLP.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyPaWNYHcw2sEl8DiZbEhHy4",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/rajlm10/D2L-Torch/blob/main/D2L_MLP.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"id": "_xA_1Uqw0Cem"
},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"from torch.utils import data\n",
"from torchvision import transforms\n",
"from torch import nn"
]
},
{
"cell_type": "code",
"source": [
"def get_fashion_mnist_labels(labels): \n",
" \"\"\"Return text labels for the Fashion-MNIST dataset.\"\"\" \n",
" text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] \n",
" return [text_labels[int(i)] for i in labels]"
],
"metadata": {
"id": "nrS71wKv0O_R"
},
"execution_count": 43,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#Dummy accuracy\n",
"def accuracy(y_hat, y):\n",
" \"\"\"Compute the number of correct predictions.\"\"\" \n",
" if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:\n",
" y_hat = y_hat.argmax(axis=1)\n",
" cmp = y_hat.type(y.dtype) == y\n",
" return float(cmp.type(y.dtype).sum())"
],
"metadata": {
"id": "FETTp0tN0Q_m"
},
"execution_count": 44,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class Accumulator: \n",
" \"\"\"For accumulating sums over `n` variables.\"\"\" \n",
" def __init__(self, n):\n",
" self.data = [0.0] * n \n",
" \n",
" def add(self, *args):\n",
" self.data = [a + float(b) for a, b in zip(self.data, args)] \n",
" \n",
" def reset(self):\n",
" self.data = [0.0] * len(self.data)\n",
" \n",
" def __getitem__(self, idx): \n",
" return self.data[idx]"
],
"metadata": {
"id": "yK3r2rx_0hIC"
},
"execution_count": 45,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def evaluate_accuracy(net,test_iter):\n",
" \"\"\"Compute the accuracy for a model on a dataset.\"\"\"\n",
" if isinstance(net,torch.nn.Module):\n",
" net.eval()\n",
" metric=Accumulator(2) #no of correct preds, no of predictions\n",
"\n",
" with torch.no_grad():\n",
" for X, y in test_iter:\n",
" metric.add(accuracy(net(X), y), y.numel()) \n",
" return metric[0] / metric[1]\n",
"\n"
],
"metadata": {
"id": "uK5wjHz70Y5Z"
},
"execution_count": 46,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_workers():\n",
" return 2"
],
"metadata": {
"id": "aYzjHyVS0uMb"
},
"execution_count": 47,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def load_fashion_mnist(batch_size,resize=None):\n",
" \"\"\"Download the Fashion-MNIST dataset and then load it into memory.\"\"\"\n",
" trans=[transforms.ToTensor()] #PIL image to tensor (normalized between 0-1)\n",
" if resize:\n",
" trans.insert(0,transforms.Resize(resize))\n",
" \n",
" trans=transforms.Compose(trans) #Chains together transforms\n",
"\n",
" mnist_train=torchvision.datasets.FashionMNIST(root=\"../data\", train=True, transform=trans, download=True)\n",
" mnist_test=torchvision.datasets.FashionMNIST(root=\"../data\", train=False, transform=trans, download=True)\n",
"\n",
" return data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_workers()),data.DataLoader(mnist_test,batch_size,shuffle=True,num_workers=get_workers())\n"
],
"metadata": {
"id": "GzE4J7dV1feS"
},
"execution_count": 48,
"outputs": []
},
{
"cell_type": "code",
"source": [
"batch_size = 256\n",
"train_iter, test_iter = load_fashion_mnist(batch_size)"
],
"metadata": {
"id": "VRLmb50v2RCE"
},
"execution_count": 49,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from torch.nn.modules.activation import ReLU\n",
"\n",
"net=nn.Sequential(\n",
" nn.Flatten(),\n",
" nn.Linear(784,256),\n",
" nn.ReLU(),\n",
" nn.Linear(256,10)\n",
")\n",
"\n",
"def init_weights(layer):\n",
" if type(layer) == nn.Linear:\n",
" nn.init.normal_(layer.weight,std=0.01)\n",
"\n",
"net.apply(init_weights)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tVh2K45o2SB6",
"outputId": "3882bcd1-2d9f-4559-b2e8-e3e08e876fea"
},
"execution_count": 50,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Sequential(\n",
" (0): Flatten(start_dim=1, end_dim=-1)\n",
" (1): Linear(in_features=784, out_features=256, bias=True)\n",
" (2): ReLU()\n",
" (3): Linear(in_features=256, out_features=10, bias=True)\n",
")"
]
},
"metadata": {},
"execution_count": 50
}
]
},
{
"cell_type": "code",
"source": [
"batch_size, lr, num_epochs = 256, 0.1, 10\n",
"loss = nn.CrossEntropyLoss(reduction='none') \n",
"trainer = torch.optim.SGD(net.parameters(), lr=lr)"
],
"metadata": {
"id": "OObH5AB028nT"
},
"execution_count": 51,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def train_epoch(net,training_set,loss,optimizer):\n",
" #Note training set is an iterator\n",
" if isinstance(net,torch.nn.Module):\n",
" net.train()\n",
"\n",
" metric=Accumulator(3) #stores sum of training loss, sum of training accuracy, no. of examples\n",
"\n",
" for X,y in training_set:\n",
" y_hat=net(X) # n X 10\n",
" l=loss(y_hat,y) # nX10, nX1 -> nX1\n",
"\n",
" if isinstance(optimizer,torch.optim.Optimizer):\n",
" optimizer.zero_grad()\n",
" l.mean().backward()\n",
" optimizer.step()\n",
"\n",
" metric.add(float(l.sum()),accuracy(y_hat,y),y.shape[0])\n",
"\n",
" return metric[0]/metric[2], metric[1]/metric[2]\n",
"\n",
"\n",
"\n"
],
"metadata": {
"id": "vOnlRtrJ8Ubs"
},
"execution_count": 52,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def train(net,training_set,test_set,loss,optimizer,num_epochs):\n",
" for epoch in range(num_epochs):\n",
" train_loss,train_acc=train_epoch(net,training_set,loss,optimizer)\n",
" test_acc = evaluate_accuracy(net, test_set)\n",
"\n",
" print(f'''epoch {epoch+1}: Train Loss: {train_loss},Train Acc: {train_acc}, Test Acc: {test_acc}''')"
],
"metadata": {
"id": "35lXghyO4mXP"
},
"execution_count": 53,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train(net,train_iter,test_iter,loss,trainer,10)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "H2pV-rMa5Npf",
"outputId": "531198ac-e9ac-4b59-9724-8048f6fefcc5"
},
"execution_count": 54,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"epoch 1: Train Loss: 1.0462628690083822,Train Acc: 0.6399666666666667, Test Acc: 0.7253\n",
"epoch 2: Train Loss: 0.5954515837351481,Train Acc: 0.7911333333333334, Test Acc: 0.7736\n",
"epoch 3: Train Loss: 0.5227659240722656,Train Acc: 0.8171, Test Acc: 0.801\n",
"epoch 4: Train Loss: 0.48079511960347493,Train Acc: 0.8312833333333334, Test Acc: 0.8302\n",
"epoch 5: Train Loss: 0.4560855724334717,Train Acc: 0.8404333333333334, Test Acc: 0.8244\n",
"epoch 6: Train Loss: 0.43415399583180747,Train Acc: 0.8482833333333333, Test Acc: 0.8348\n",
"epoch 7: Train Loss: 0.42083296286265054,Train Acc: 0.8528666666666667, Test Acc: 0.8248\n",
"epoch 8: Train Loss: 0.40448241259257,Train Acc: 0.8577, Test Acc: 0.8477\n",
"epoch 9: Train Loss: 0.3944818537394206,Train Acc: 0.8616666666666667, Test Acc: 0.8496\n",
"epoch 10: Train Loss: 0.3821519070943197,Train Acc: 0.8643833333333333, Test Acc: 0.8415\n"
]
}
]
}
]
}

0 comments on commit a5a8417

Please sign in to comment.