From 5a6f5e77a63cebefb2087901ecaad22969ff2281 Mon Sep 17 00:00:00 2001 From: saliherdemk Date: Mon, 14 Oct 2024 18:09:07 +0300 Subject: [PATCH] shown calculation parameters corrected --- js/Draw/MLP/MLPView.js | 18 ++++++++++++------ js/Draw/MLP/Playable.js | 4 +++- js/MLP/Mlp.js | 19 +++++++++---------- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/js/Draw/MLP/MLPView.js b/js/Draw/MLP/MLPView.js index a045e9f..c7792eb 100644 --- a/js/Draw/MLP/MLPView.js +++ b/js/Draw/MLP/MLPView.js @@ -179,16 +179,23 @@ class MlpView extends Playable { }; } + setCalculationData() { + const layers = this.origin.layers; + const slicedDataAll = []; + for (let i = 0; i < layers.length; i++) { + const { weights, biases, z, outputs } = layers[i]; + slicedDataAll.push(this.sliceData(weights, biases, z, outputs)); + } + this.calculationComponent?.setData(slicedDataAll); + } + updateParameters() { const layersElements = this.getAllParameters(); const layers = this.origin.layers; - const slicedDataAll = []; for (let i = 0; i < layers.length; i++) { - const { weights, biases, z, outputs } = layers[i]; - if (this.calculationComponent) { - slicedDataAll.push(this.sliceData(weights, biases, z, outputs)); - } + const { weights, biases, outputs } = layers[i]; + const { lines, neurons } = layersElements[i]; for (let i = 0; i < neurons.length; i++) { neurons[i].setOutput(outputs.data[0][i]); @@ -203,7 +210,6 @@ class MlpView extends Playable { } } } - this.calculationComponent?.setData(slicedDataAll); } setOrigin(obj) { diff --git a/js/Draw/MLP/Playable.js b/js/Draw/MLP/Playable.js index d8eff44..eb88bdd 100644 --- a/js/Draw/MLP/Playable.js +++ b/js/Draw/MLP/Playable.js @@ -308,7 +308,9 @@ class Playable extends Draggable { [parseInt(this.batchSize), this.getInput().shape[1]], ); const outputData = this.getOutput()?.getData() ?? null; - await this.origin.trainOneStep(inputData, outputData); + const mlp_output = await this.origin.forward(inputData); + this.calculationComponent && this.setCalculationData(); + await this.origin.backward(mlp_output, outputData); this.updateParameters(); this.graphComponent?.setData(this.origin.getLossData()); this.setMsPerStepText(performance.now() - startTime + "ms / step"); diff --git a/js/MLP/Mlp.js b/js/MLP/Mlp.js index 6a01b78..57fba82 100644 --- a/js/MLP/Mlp.js +++ b/js/MLP/Mlp.js @@ -14,13 +14,6 @@ class MLP extends MlpParams { } } - forward(inputs) { - for (let i = 0; i < this.layers.length; i++) { - inputs = this.layers[i].forward(inputs); - } - return inputs; - } - getParameters() { const allWeights = []; const allBiases = []; @@ -50,10 +43,16 @@ class MLP extends MlpParams { }); } - trainOneStep(x_batch, y_batch) { - const mlp_output = this.forward(new Tensor(x_batch)); - if (this.mode == "eval") return; + forward(x_batch) { + let inputs = new Tensor(x_batch); + for (let i = 0; i < this.layers.length; i++) { + inputs = this.layers[i].forward(inputs); + } + return inputs; + } + backward(mlp_output, y_batch) { + if (this.mode == "eval") return; const loss = errFuncManager.getFunction(this.errFunc)( mlp_output, new Tensor(y_batch),