Skip to content

Commit

Permalink
keras fix layer sharing inputs names
Browse files Browse the repository at this point in the history
  • Loading branch information
Chizkiyahu committed Oct 6, 2023
1 parent d83c4c0 commit c7874c8
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions source/keras.js
Original file line number Diff line number Diff line change
Expand Up @@ -756,20 +756,34 @@ keras.Node = class {
const innerType = this.inner ? this.inner.type : null;
const innerMetadata = innerType ? metadata.type(innerType) : null;
let inputIndex = 0;
const typeInputs = [];
if (this._type && this._type.inputs) {
for (let i = 0; i < this._type.inputs.length; i++) {
if (i === 0 && layer.inbound_nodes) {
for (let j = 0; j < layer.inbound_nodes.length; j++) {
let inputType = JSON.parse(JSON.stringify(this._type.inputs[i]));

Check failure on line 764 in source/keras.js

View workflow job for this annotation

GitHub Actions / Build (ubuntu-latest)

'inputType' is never reassigned. Use 'const' instead
inputType.name = inputType.name + " edge " + j;
typeInputs.push(inputType);
}
} else {
typeInputs.push(this._type.inputs[i]);
}
}
}
while (inputs.length > 0) {
let list = false;
let name = null;
let visible = true;
if (!innerMetadata || inputIndex == 0) {
if (this._type && this._type.inputs && inputIndex < this._type.inputs.length) {
const input = this._type.inputs[inputIndex];
if (this._type && typeInputs && inputIndex < typeInputs.length) {
const input = typeInputs[inputIndex];
name = input.name;
if (type === 'BatchNormalization' && name === 'gamma' && config.scale === false) {
inputIndex++;
continue;
}
visible = input.visible == false ? false : true;
if (this._type.inputs[inputIndex].list) {
if (typeInputs[inputIndex].list) {
list = true;
}
}
Expand Down Expand Up @@ -834,7 +848,7 @@ keras.Node = class {
this._outputs.push(argument);
}

const inputTypes = new Map((this._type.inputs || []).map((input) => [ input.name, input.type ]));
const inputTypes = new Map((typeInputs || []).map((input) => [ input.name, input.type ]));
for (const entry of Object.entries(args)) {
const name = entry[0];
const arg = entry[1];
Expand Down

0 comments on commit c7874c8

Please sign in to comment.