Skip to content

Commit

Permalink
shown calculation parameters corrected
Browse files Browse the repository at this point in the history
  • Loading branch information
saliherdemk committed Oct 14, 2024
1 parent 0027ff2 commit 5a6f5e7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 17 deletions.
18 changes: 12 additions & 6 deletions js/Draw/MLP/MLPView.js
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,23 @@ class MlpView extends Playable {
};
}

setCalculationData() {
const layers = this.origin.layers;
const slicedDataAll = [];
for (let i = 0; i < layers.length; i++) {
const { weights, biases, z, outputs } = layers[i];
slicedDataAll.push(this.sliceData(weights, biases, z, outputs));
}
this.calculationComponent?.setData(slicedDataAll);
}

updateParameters() {
const layersElements = this.getAllParameters();
const layers = this.origin.layers;
const slicedDataAll = [];

for (let i = 0; i < layers.length; i++) {
const { weights, biases, z, outputs } = layers[i];
if (this.calculationComponent) {
slicedDataAll.push(this.sliceData(weights, biases, z, outputs));
}
const { weights, biases, outputs } = layers[i];

const { lines, neurons } = layersElements[i];
for (let i = 0; i < neurons.length; i++) {
neurons[i].setOutput(outputs.data[0][i]);
Expand All @@ -203,7 +210,6 @@ class MlpView extends Playable {
}
}
}
this.calculationComponent?.setData(slicedDataAll);
}

setOrigin(obj) {
Expand Down
4 changes: 3 additions & 1 deletion js/Draw/MLP/Playable.js
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,9 @@ class Playable extends Draggable {
[parseInt(this.batchSize), this.getInput().shape[1]],
);
const outputData = this.getOutput()?.getData() ?? null;
await this.origin.trainOneStep(inputData, outputData);
const mlp_output = await this.origin.forward(inputData);
this.calculationComponent && this.setCalculationData();
await this.origin.backward(mlp_output, outputData);
this.updateParameters();
this.graphComponent?.setData(this.origin.getLossData());
this.setMsPerStepText(performance.now() - startTime + "ms / step");
Expand Down
19 changes: 9 additions & 10 deletions js/MLP/Mlp.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@ class MLP extends MlpParams {
}
}

forward(inputs) {
for (let i = 0; i < this.layers.length; i++) {
inputs = this.layers[i].forward(inputs);
}
return inputs;
}

getParameters() {
const allWeights = [];
const allBiases = [];
Expand Down Expand Up @@ -50,10 +43,16 @@ class MLP extends MlpParams {
});
}

trainOneStep(x_batch, y_batch) {
const mlp_output = this.forward(new Tensor(x_batch));
if (this.mode == "eval") return;
forward(x_batch) {
let inputs = new Tensor(x_batch);
for (let i = 0; i < this.layers.length; i++) {
inputs = this.layers[i].forward(inputs);
}
return inputs;
}

backward(mlp_output, y_batch) {
if (this.mode == "eval") return;
const loss = errFuncManager.getFunction(this.errFunc)(
mlp_output,
new Tensor(y_batch),
Expand Down

0 comments on commit 5a6f5e7

Please sign in to comment.