-
Notifications
You must be signed in to change notification settings - Fork 0
/
neural-network-mnist.js
33 lines (28 loc) · 989 Bytes
/
neural-network-mnist.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
const brain = require('brain.js');
const mnist = require('mnist'); // Library for loading the MNIST dataset
// Draw the MNIST samples: 8000 for training, 2000 held out for testing.
const set = mnist.set(8000, 2000);
const { training: trainingSamples } = set;

// Feed-forward network with two hidden layers of 784 and 392 neurons.
const net = new brain.NeuralNetwork({ hiddenLayers: [784, 392] });

// Training knobs passed straight to brain.js: progress logging every
// logPeriod iterations, step size, and the target error to stop at.
const trainingOptions = {
  log: true,
  logPeriod: 1,
  learningRate: 0.05,
  errorThresh: 0.005,
};
net.train(trainingSamples, trainingOptions);
// Test the trained network on the held-out MNIST samples.

/**
 * Index of the largest element (the first one on ties), or -1 when `values`
 * is empty or has no element greater than -Infinity (e.g. all NaN).
 * Works for plain arrays and typed arrays alike, and avoids spreading the
 * whole array into Math.max.
 * @param {ArrayLike<number>} values
 * @returns {number}
 */
function argmax(values) {
  let bestIndex = -1;
  let bestValue = -Infinity;
  for (let i = 0; i < values.length; i += 1) {
    // Strict `>` keeps the first index on ties, like indexOf(Math.max(...)).
    if (values[i] > bestValue) {
      bestValue = values[i];
      bestIndex = i;
    }
  }
  return bestIndex;
}

/**
 * Fraction of `samples` whose predicted digit matches the labelled digit.
 * The prediction is the argmax of `network.run(input)`. The label is the
 * position of the 1 in `output` (assumes one-hot labels). A label with no 1
 * never counts as correct, even if the prediction is also -1.
 * @param {{ run: (input: number[]) => ArrayLike<number> }} network
 * @param {Array<{ input: number[], output: number[] }>} samples
 * @returns {number} accuracy in [0, 1]
 * @throws {Error} if `samples` is empty (instead of reporting NaN%).
 */
function evaluateAccuracy(network, samples) {
  if (samples.length === 0) {
    throw new Error('Cannot compute accuracy: the test set is empty');
  }
  let correct = 0;
  for (const { input, output } of samples) {
    const actualDigit = output.indexOf(1);
    if (actualDigit !== -1 && argmax(network.run(input)) === actualDigit) {
      correct += 1;
    }
  }
  return correct / samples.length;
}

const testingData = set.test;
const accuracy = evaluateAccuracy(net, testingData);
// toFixed keeps binary floating-point noise out of the printed percentage.
console.log(`Accuracy: ${(accuracy * 100).toFixed(2)}%`);