Skip to content

Commit

Permalink
eval loss graph added
Browse files Browse the repository at this point in the history
  • Loading branch information
saliherdemk committed Oct 16, 2024
1 parent 21346b0 commit b69bb52
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 18 deletions.
11 changes: 9 additions & 2 deletions js/Draw/MLP/MLPView.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class MlpView extends Playable {
this.selected = false;
}

isEval() {
return this.mode == "eval";
}

getAllParameters() {
let layersParameters = [];

Expand Down Expand Up @@ -90,7 +94,10 @@ class MlpView extends Playable {

handleSetMode(mode) {
this.setMode(mode);
this.isInitialized() && this.checkCompleted();
if (this.isInitialized()) {
this.setGraphComponentData();
this.checkCompleted();
}
}

handleSetZenMode(mode) {
Expand Down Expand Up @@ -407,7 +414,7 @@ class MlpView extends Playable {
{ func: "text", args: [`Total Parameters: ${totalParams}\n`, x, y - 10] },
];

const commands = this.getMode() == "train" ? this.getTrainCommands() : [];
const commands = this.isEval() ? [] : this.getTrainCommands();

executeDrawingCommands([...commands, ...common]);
}
Expand Down
18 changes: 12 additions & 6 deletions js/Draw/MLP/Playable.js
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,14 @@ class Playable extends Draggable {
this.x - (graphComponent.w - this.w) / 2,
this.y - 400,
);
graphComponent.setData(this.origin.getLossData());
this.graphComponent = graphComponent;
this.setGraphComponentData();
}

setGraphComponentData() {
this.graphComponent?.setData(
this.isEval() ? this.origin.getEvalLoss() : this.origin.getTrainLoss(),
);
}

removeGraphComponent() {
Expand Down Expand Up @@ -134,13 +140,12 @@ class Playable extends Draggable {
}

checkCompleted() {
const isEval = this.getMode() == "eval";
this.pause();
this.updateStatus(
+(
(this.getInput() instanceof InputLayer ||
this.getInput() instanceof DigitComponent) &&
(isEval || this.getOutput() instanceof OutputLayer)
(this.isEval() || this.getOutput() instanceof OutputLayer)
),
);

Expand Down Expand Up @@ -304,16 +309,17 @@ class Playable extends Draggable {
async executeOnce() {
let startTime = performance.now();
const inputData = this.getInput().getData();
const origin = this.origin;
this.calculationComponent?.setInputData(
inputData.slice(0, 5).map((row) => row.slice(0, 5)),
[parseInt(this.batchSize), this.getInput().shape[1]],
);
const outputData = this.getOutput()?.getData() ?? null;
const mlp_output = await this.origin.forward(inputData);
const mlp_output = await origin.forward(inputData);
this.calculationComponent && this.setCalculationData();
await this.origin.backward(mlp_output, outputData);
outputData && (await origin.backward(mlp_output, outputData));
this.updateParameters();
this.graphComponent?.setData(this.origin.getLossData());
this.setGraphComponentData();
this.setMsPerStepText(performance.now() - startTime + "ms / step");
}

Expand Down
15 changes: 10 additions & 5 deletions js/MLP/Mlp.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,15 @@ class MLP extends MlpParams {
}

backward(mlp_output, y_batch) {
if (this.mode == "eval") return;
const loss = errFuncManager.getFunction(this.errFunc)(
mlp_output,
new Tensor(y_batch),
);
this.addLossData(loss.data[0][0]);
if (this.mode == "eval") {
this.addEvalLoss(loss.data[0][0]);
return;
}
this.addTrainLoss(loss.data[0][0]);

this.zeroGrad();
loss.backward();
Expand All @@ -74,14 +77,16 @@ class MLP extends MlpParams {
biases: biases.map((b) => b.data),
seenRecordNum: this.seenRecordNum,
stepCounter: this.stepCounter,
lossData: this.lossData,
trainLoss: this.trainLoss,
evalLoss: this.evalLoss,
};
}

import({ weights, biases, seenRecordNum, stepCounter, lossData }) {
import({ weights, biases, seenRecordNum, stepCounter, trainLoss, evalLoss }) {
this.seenRecordNum = seenRecordNum;
this.stepCounter = stepCounter;
this.lossData = lossData;
this.trainLoss = trainLoss;
this.evalLoss = evalLoss;
for (let i = 0; i < this.layers.length; i++) {
this.layers[i].setParameters(weights[i], biases[i]);
}
Expand Down
19 changes: 14 additions & 5 deletions js/MLP/MlpParams.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,24 @@ class MlpParams {
this.epoch = 0;
this.seenRecordNum = 0;
this.mode = "train";
this.lossData = [];
this.trainLoss = [];
this.evalLoss = [];
}

getLossData() {
return this.lossData;
getTrainLoss() {
return this.trainLoss;
}

addLossData(lossData) {
this.lossData.push(lossData);
getEvalLoss() {
return this.evalLoss;
}

addTrainLoss(lossData) {
this.trainLoss.push(lossData);
}

addEvalLoss(lossData) {
this.evalLoss.push(lossData);
}

setLr(lr) {
Expand Down

0 comments on commit b69bb52

Please sign in to comment.