Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update C1_W2_Assignment.js #45

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
142 changes: 93 additions & 49 deletions C1_Browser-based-TF-JS/W2/assignment/C1_W2_Assignment.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import {FMnistData} from './fashion-data.js';
import { FMnistData } from './fashion-data.js';

var canvas, ctx, saveButton, clearButton;
var pos = {x:0, y:0};
var pos = { x: 0, y: 0 };
var rawImage;
var model;

function getModel() {

// In the space below create a convolutional neural network that can classify the
// images of articles of clothing in the Fashion MNIST dataset. Your convolutional
// neural network should only use the following layers: conv2d, maxPooling2d,
Expand All @@ -14,48 +15,94 @@ function getModel() {
// many layers, filters, and neurons as you like.
// HINT: Take a look at the MNIST example.
model = tf.sequential();

// YOUR CODE HERE


// Compile the model using the categoricalCrossentropy loss,

// Add the first convolutional layer
model.add(tf.layers.conv2d({
inputShape: [28, 28, 1],
kernelSize: 3,
filters: 32,
activation: 'relu',
kernelInitializer: 'varianceScaling'
}));

// Add a max pooling layer
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));

// Add another convolutional layer
model.add(tf.layers.conv2d({
kernelSize: 3,
filters: 64,
activation: 'relu'
}));

// Add a max pooling layer
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));

// Add a flatten layer
model.add(tf.layers.flatten());

// Add a dense layer
model.add(tf.layers.dense({
units: 128,
activation: 'relu'
}));

// Add the output layer
model.add(tf.layers.dense({
units: 10,
activation: 'softmax'
}));

// Compile the model using categoricalCrossentropy loss,
// the tf.train.adam() optimizer, and `acc` for your metrics.
model.compile(// YOUR CODE HERE);

model.compile({
optimizer: tf.train.adam(),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});

return model;
}

async function train(model, data) {

// Set the following metrics for the callback: 'loss', 'val_loss', 'acc', 'val_acc'.
const metrics = // YOUR CODE HERE
const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];


// Create the container for the callback. Set the name to 'Model Training' and
// use a height of 1000px for the styles.
const container = // YOUR CODE HERE


const container = document.getElementById('main');

// Use tfvis.show.fitCallbacks() to setup the callbacks.
// Use the container and metrics defined above as the parameters.
const fitCallbacks = // YOUR CODE HERE
const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

const BATCH_SIZE = 512;
const TRAIN_DATA_SIZE = 6000;
const TEST_DATA_SIZE = 1000;

// Get the training batches and resize them. Remember to put your code
// inside a tf.tidy() clause to clean up all the intermediate tensors.
// HINT: Take a look at the MNIST example.
const [trainXs, trainYs] = // YOUR CODE HERE
const [trainXs, trainYs] = tf.tidy(() => {
const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
return [
d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
d.labels
];
});


// Get the testing batches and resize them. Remember to put your code
// inside a tf.tidy() clause to clean up all the intermediate tensors.
// HINT: Take a look at the MNIST example.
const [testXs, testYs] = // YOUR CODE HERE
const [testXs, testYs] = tf.tidy(() => {
const d = data.nextTestBatch(TEST_DATA_SIZE);
return [
d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
d.labels
];
});


return model.fit(trainXs, trainYs, {
batchSize: BATCH_SIZE,
validationData: [testXs, testYs],
Expand All @@ -65,13 +112,13 @@ async function train(model, data) {
});
}

function setPosition(e){
pos.x = e.clientX-100;
pos.y = e.clientY-100;
function setPosition(e) {
pos.x = e.clientX - 100;
pos.y = e.clientY - 100;
}

function draw(e) {
if(e.buttons!=1) return;
if (e.buttons != 1) return;
ctx.beginPath();
ctx.lineWidth = 24;
ctx.lineCap = 'round';
Expand All @@ -82,56 +129,53 @@ function draw(e) {
ctx.stroke();
rawImage.src = canvas.toDataURL('image/png');
}

function erase() {
ctx.fillStyle = "black";
ctx.fillRect(0,0,280,280);
ctx.fillRect(0, 0, 280, 280);
}

function save() {
var raw = tf.browser.fromPixels(rawImage,1);
var resized = tf.image.resizeBilinear(raw, [28,28]);
var raw = tf.browser.fromPixels(rawImage, 1);
var resized = tf.image.resizeBilinear(raw, [28, 28]);
var tensor = resized.expandDims(0);

var prediction = model.predict(tensor);
var pIndex = tf.argMax(prediction, 1).dataSync();

var classNames = ["T-shirt/top", "Trouser", "Pullover",
"Dress", "Coat", "Sandal", "Shirt",
"Sneaker", "Bag", "Ankle boot"];



var classNames = ["T-shirt/top", "Trouser", "Pullover",
"Dress", "Coat", "Sandal", "Shirt",
"Sneaker", "Bag", "Ankle boot"
];


alert(classNames[pIndex]);
}

function init() {
canvas = document.getElementById('canvas');
rawImage = document.getElementById('canvasimg');
ctx = canvas.getContext("2d");
ctx.fillStyle = "black";
ctx.fillRect(0,0,280,280);
ctx.fillRect(0, 0, 280, 280);
canvas.addEventListener("mousemove", draw);
canvas.addEventListener("mousedown", setPosition);
canvas.addEventListener("mouseenter", setPosition);
saveButton = document.getElementById('sb');
saveButton = document.getElementById('classifyBtn');
saveButton.addEventListener("click", save);
clearButton = document.getElementById('cb');
clearButton = document.getElementById('clearBtn');
clearButton.addEventListener("click", erase);
}


async function run() {
const data = new FMnistData();
await data.load();
const model = getModel();
tfvis.show.modelSummary({name: 'Model Architecture'}, model);
tfvis.show.modelSummary({ name: 'Model Architecture' }, model);
await train(model, data);
await model.save('downloads://my_model');
init();
alert("Training is done, try classifying your drawings!");
}

document.addEventListener('DOMContentLoaded', run);



86 changes: 81 additions & 5 deletions C1_Browser-based-TF-JS/W2/assignment/fashion-mnist.html
Original file line number Diff line number Diff line change
@@ -1,17 +1,93 @@
<html>
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Fashion Classifier</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>

</head>
<body>
<h1>Fashion Classifier!</h1>
<canvas id="canvas" width="280" height="280" style="position:absolute;top:100;left:100;border:8px solid;"></canvas>
<img id="canvasimg" style="position:absolute;top:10%;left:52%;width:280;height:280;display:none;">
<input type="button" value="classify" id="sb" size="48" style="position:absolute;top:400;left:100;">
<input type="button" value="clear" id="cb" size="23" style="position:absolute;top:400;left:180;">
<input type="button" value="Classify" id="classifyBtn" size="48" style="position:absolute;top:400;left:100;">
<input type="button" value="Clear" id="clearBtn" size="23" style="position:absolute;top:400;left:180;">
<div id="main" style="height: 1000px;"></div>
<script src="fashion-data.js" type="module"></script>
<script src="C1_W2_Assignment.js" type="module"></script>
<!-- <script src="C1_W2_Assignment_Solution.js" type="module"></script> -->
<script>
var canvas, ctx, classifyButton, clearButton;
var pos = { x: 0, y: 0 };
var rawImage;
var model;

// Function to initialize the canvas and buttons
function init() {
canvas = document.getElementById('canvas');
rawImage = document.getElementById('canvasimg');
ctx = canvas.getContext("2d");
ctx.fillStyle = "black";
ctx.fillRect(0, 0, 280, 280);
canvas.addEventListener("mousemove", draw);
canvas.addEventListener("mousedown", setPosition);
canvas.addEventListener("mouseenter", setPosition);

// Add event listeners for Classify and Clear buttons
classifyButton = document.getElementById('classifyBtn');
classifyButton.addEventListener("click", classify);

clearButton = document.getElementById('clearBtn');
clearButton.addEventListener("click", erase);
}

// Function to handle mouse movements and draw on the canvas
function draw(e) {
if (e.buttons != 1) return;
ctx.beginPath();
ctx.lineWidth = 24;
ctx.lineCap = 'round';
ctx.strokeStyle = 'white';
ctx.moveTo(pos.x, pos.y);
setPosition(e);
ctx.lineTo(pos.x, pos.y);
ctx.stroke();
rawImage.src = canvas.toDataURL('image/png');
}

// Function to set the position for drawing
function setPosition(e) {
pos.x = e.clientX - 100;
pos.y = e.clientY - 100;
}

// Function to clear the canvas
function erase() {
ctx.fillStyle = "black";
ctx.fillRect(0, 0, 280, 280);
}

// Function to handle classification when Classify button is clicked
function classify() {
// Add your classification logic here
// For example, you can display an alert message
alert("Classifying...");
}

// Function to load the model and initialize the canvas
async function run() {
const data = new FMnistData();
await data.load();
const model = getModel();
tfvis.show.modelSummary({ name: 'Model Architecture' }, model);
await train(model, data);
await model.save('downloads://my_model');
init();
alert("Training is done, try classifying your drawings!");
}

// Run the initialization function when the DOM content is loaded
document.addEventListener('DOMContentLoaded', run);
</script>
</body>
</html>