Skip to content

Commit

Permalink
Change copy functions over to use (de)serialization feature
Browse files Browse the repository at this point in the history
The (de)serialzation functions already have all the code needed to make a deep copy
of the NeuralNetwork. This commit also adds a test for the `Matrix.copy()` function.
This also copies the learning rate and if CodingTrain#97 is merged the activation function is
copied as well.
  • Loading branch information
enginefeeder101 committed Mar 10, 2018
1 parent 846a259 commit 94eb331
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 39 deletions.
8 changes: 1 addition & 7 deletions lib/matrix.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,7 @@ class Matrix {
}

copy() {
let m = new Matrix(this.rows, this.cols);
for (let i = 0; i < this.rows; i++) {
for (let j = 0; j < this.cols; j++) {
m.data[i][j] = this.data[i][j];
}
}
return m;
return Matrix.deserialize(this.serialize());
}

static fromArray(arr) {
Expand Down
14 changes: 14 additions & 0 deletions lib/matrix.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ test('static map with row and column params', () => {
]
});
});

test('matrix (de)serialization', () => {
let m = new Matrix(5, 5);
m.randomize();
Expand All @@ -375,3 +376,16 @@ test('matrix (de)serialization', () => {
data: m.data
});
});

test('matrix copy', () => {
let m = new Matrix(5, 5);
m.randomize();

let n = m.copy();

expect(n).toEqual({
rows: m.rows,
cols: m.cols,
data: m.data
});
});
47 changes: 15 additions & 32 deletions lib/nn.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,22 @@ let tanh = new ActivationFunction(


class NeuralNetwork {
// TODO: document what a, b, c are
constructor(a, b, c) {
if (a instanceof NeuralNetwork) {
this.input_nodes = a.input_nodes;
this.hidden_nodes = a.hidden_nodes;
this.output_nodes = a.output_nodes;

this.weights_ih = a.weights_ih.copy();
this.weights_ho = a.weights_ho.copy();

this.bias_h = a.bias_h.copy();
this.bias_o = a.bias_o.copy();
} else {
this.input_nodes = a;
this.hidden_nodes = b;
this.output_nodes = c;

this.weights_ih = new Matrix(this.hidden_nodes, this.input_nodes);
this.weights_ho = new Matrix(this.output_nodes, this.hidden_nodes);
this.weights_ih.randomize();
this.weights_ho.randomize();

this.bias_h = new Matrix(this.hidden_nodes, 1);
this.bias_o = new Matrix(this.output_nodes, 1);
this.bias_h.randomize();
this.bias_o.randomize();
}

// TODO: copy these as well
constructor(input_nodes, hidden_nodes, output_nodes) {
this.input_nodes = input_nodes;
this.hidden_nodes = hidden_nodes;
this.output_nodes = output_nodes;

this.weights_ih = new Matrix(this.hidden_nodes, this.input_nodes);
this.weights_ho = new Matrix(this.output_nodes, this.hidden_nodes);
this.weights_ih.randomize();
this.weights_ho.randomize();

this.bias_h = new Matrix(this.hidden_nodes, 1);
this.bias_o = new Matrix(this.output_nodes, 1);
this.bias_h.randomize();
this.bias_o.randomize();
this.setLearningRate();
this.setActivationFunction();


}

predict(input_array) {
Expand Down Expand Up @@ -158,7 +141,7 @@ class NeuralNetwork {

// Adding function for neuro-evolution
copy() {
return new NeuralNetwork(this);
return NeuralNetwork.deserialize(this.serialize());
}

mutate(rate) {
Expand Down

0 comments on commit 94eb331

Please sign in to comment.