Skip to content

Commit

Permalink
Faster row/col/depth min indices (firstLastNonZero3D) 200ms -> 50ms
Browse files Browse the repository at this point in the history
  • Loading branch information
neurolabusc committed May 15, 2024
1 parent c3d2964 commit 3e61e69
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 138 deletions.
122 changes: 39 additions & 83 deletions brainchop-mainthread.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,43 @@ async function load_model(modelUrl) {
return await tf.loadLayersModel(modelUrl)
}

// return first and last non-zero voxel in row (dim = 0), column (1) or slice (2) dimension
async function firstLastNonZero(tensor3D, dim = 0) {
let mxs = []
if (dim === 0) {
mxs = await tensor3D.max(2).max(1).arraySync()
} else if (dim === 1) {
mxs = await tensor3D.max(2).max(0).arraySync()
} else {
mxs = await tensor3D.max(1).max(0).arraySync()
}
let mn = mxs.length
let mx = 0
for (let i = 0; i < mxs.length; i++) {
if (mxs[i] > 0) {
mn = i
break
}
}
for (let i = mxs.length - 1; i >= 0; i--) {
if (mxs[i] > 0) {
mx = i
break
}
}
return [mn, mx]
}

async function firstLastNonZero3D(tensor3D) {
const [row_min, row_max] = await firstLastNonZero(tensor3D, 0)
const [col_min, col_max] = await firstLastNonZero(tensor3D, 1)
const [depth_min, depth_max] = await firstLastNonZero(tensor3D, 2)
console.log('row min and max :', row_min, row_max)
console.log('col min and max :', col_min, col_max)
console.log('depth min and max :', depth_min, depth_max)
return [row_min, row_max, col_min, col_max, depth_min, depth_max]
}

async function getAllSlicesDataAsTF3D(num_of_slices, niftiHeader, niftiImage) {
// Get nifti dimensions
const cols = niftiHeader.dims[1] // Slice width
Expand Down Expand Up @@ -709,54 +746,14 @@ async function inferenceFullVolumeSeqCovLayerPhase2(
mask_3d = await pipeline1_out.greater([0]).asType('bool')
// -- pipeline1_out.dispose();
}

console.log(' mask_3d shape : ', mask_3d.shape)

const coords = await tf.whereAsync(mask_3d)
// -- Get each voxel coords (x, y, z)

const [row_min, row_max, col_min, col_max, depth_min, depth_max] = await firstLastNonZero3D(mask_3d)
mask_3d.dispose()

const coordsArr = coords.arraySync()

let row_min = slice_height
let row_max = 0
let col_min = slice_width
let col_max = 0
let depth_min = num_of_slices
let depth_max = 0

for (let i = 0; i < coordsArr.length; i++) {
if (row_min > coordsArr[i][0]) {
row_min = coordsArr[i][0]
} else if (row_max < coordsArr[i][0]) {
row_max = coordsArr[i][0]
}

if (col_min > coordsArr[i][1]) {
col_min = coordsArr[i][1]
} else if (col_max < coordsArr[i][1]) {
col_max = coordsArr[i][1]
}

if (depth_min > coordsArr[i][2]) {
depth_min = coordsArr[i][2]
} else if (depth_max < coordsArr[i][2]) {
depth_max = coordsArr[i][2]
}
}

console.log('row min and max :', row_min, row_max)
console.log('col min and max :', col_min, col_max)
console.log('depth min and max :', depth_min, depth_max)

// -- Reference voxel that cropped volume started slice with it
const refVoxel = [row_min, col_min, depth_min]
// -- Starting form refVoxel, size of bounding volume
const boundVolSizeArr = [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]

coords.dispose()

// -- Extract 3d object (e.g. brain)
const cropped_slices_3d = await slices_3d.slice(
[row_min, col_min, depth_min],
Expand Down Expand Up @@ -865,10 +862,6 @@ async function inferenceFullVolumeSeqCovLayerPhase2(

const curTensor = []
curTensor[0] = await cropped_slices_3d_w_pad.reshape(adjusted_input_shape)
// console.log("curTensor[0] :", curTensor[0].dataSync());

// let curProgBar = parseInt(document.getElementById("progressBar").style.width);

const timer = window.setInterval(async function () {
try {
if (res.layers[i].activation.getClassName() !== 'linear') {
Expand Down Expand Up @@ -1071,7 +1064,6 @@ async function inferenceFullVolumeSeqCovLayerPhase2(
callbackUI(unreliableReasons, NaN, unreliableReasons)
}
}
// });
}

async function inferenceFullVolumePhase2(
Expand Down Expand Up @@ -1121,42 +1113,8 @@ async function inferenceFullVolumePhase2(
// -- pipeline1_out.dispose()
}
console.log(' mask_3d shape : ', mask_3d.shape)
const coords = await tf.whereAsync(mask_3d)
// -- Get each voxel coords (x, y, z)
const [row_min, row_max, col_min, col_max, depth_min, depth_max] = await firstLastNonZero3D(mask_3d)
mask_3d.dispose()
const coordsArr = coords.arraySync()

let row_min = slice_height
let row_max = 0
let col_min = slice_width
let col_max = 0
let depth_min = num_of_slices
let depth_max = 0

for (let i = 0; i < coordsArr.length; i++) {
if (row_min > coordsArr[i][0]) {
row_min = coordsArr[i][0]
} else if (row_max < coordsArr[i][0]) {
row_max = coordsArr[i][0]
}

if (col_min > coordsArr[i][1]) {
col_min = coordsArr[i][1]
} else if (col_max < coordsArr[i][1]) {
col_max = coordsArr[i][1]
}

if (depth_min > coordsArr[i][2]) {
depth_min = coordsArr[i][2]
} else if (depth_max < coordsArr[i][2]) {
depth_max = coordsArr[i][2]
}
}

console.log('row min and max :', row_min, row_max)
console.log('col min and max :', col_min, col_max)
console.log('depth min and max :', depth_min, depth_max)

// -- Reference voxel that cropped volume started slice with it
const refVoxel = [row_min, col_min, depth_min]
console.log('refVoxel :', refVoxel)
Expand All @@ -1166,8 +1124,6 @@ async function inferenceFullVolumePhase2(

console.log('boundVolSizeArr :', boundVolSizeArr)

coords.dispose()

// -- Extract 3d object (e.g. brain)
const cropped_slices_3d = slices_3d.slice(
[row_min, col_min, depth_min],
Expand Down
110 changes: 55 additions & 55 deletions brainchop-webworker.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,58 @@ async function load_model(modelUrl) {
return await tf.loadLayersModel(modelUrl)
}

// return first and last non-zero voxel in row (dim = 0), column (1) or slice (2) dimension
async function firstLastNonZero(tensor3D, dim = 0) {
let mxs = []
if (dim === 0) {
mxs = await tensor3D.max(2).max(1).arraySync()
} else if (dim === 1) {
mxs = await tensor3D.max(2).max(0).arraySync()
} else {
mxs = await tensor3D.max(1).max(0).arraySync()
}
let mn = mxs.length
let mx = 0
for (let i = 0; i < mxs.length; i++) {
if (mxs[i] > 0) {
mn = i
break
}
}
for (let i = mxs.length - 1; i >= 0; i--) {
if (mxs[i] > 0) {
mx = i
break
}
}
return [mn, mx]
}

async function firstLastNonZero3D(tensor3D) {
const [row_min, row_max] = await firstLastNonZero(tensor3D, 0)
const [col_min, col_max] = await firstLastNonZero(tensor3D, 1)
const [depth_min, depth_max] = await firstLastNonZero(tensor3D, 2)
console.log('row min and max :', row_min, row_max)
console.log('col min and max :', col_min, col_max)
console.log('depth min and max :', depth_min, depth_max)
return [row_min, row_max, col_min, col_max, depth_min, depth_max]
}

/*
//simpler function, but x4 slower
async function firstLastNonZero3D(tensor3D) {
const coords = await tf.whereAsync(tensor3D)
const row_min = coords.min(0).arraySync()[0]
const row_max = coords.max(0).arraySync()[0]
const col_min = coords.min(0).arraySync()[1]
const col_max = coords.max(0).arraySync()[1]
const depth_min = coords.min(0).arraySync()[2]
const depth_max = coords.max(0).arraySync()[2]
coords.dispose()
return [row_min, row_max, col_min, col_max, depth_min, depth_max]
}
*/

async function getAllSlicesDataAsTF3D(num_of_slices, niftiHeader, niftiImage) {
// Get nifti dimensions
const cols = niftiHeader.dims[1] // Slice width
Expand Down Expand Up @@ -734,28 +786,13 @@ async function inferenceFullVolumeSeqCovLayerPhase2(
}

console.log(' mask_3d shape : ', mask_3d.shape)
const coords = await tf.whereAsync(mask_3d)
// -- Get each voxel coords (x, y, z)

const [row_min, row_max, col_min, col_max, depth_min, depth_max] = await firstLastNonZero3D(mask_3d)
mask_3d.dispose()
const row_min = coords.min(0).arraySync()[0]
const row_max = coords.max(0).arraySync()[0]
const col_min = coords.min(0).arraySync()[1]
const col_max = coords.max(0).arraySync()[1]
const depth_min = coords.min(0).arraySync()[2]
const depth_max = coords.max(0).arraySync()[2]

console.log('row min and max :', row_min, row_max)
console.log('col min and max :', col_min, col_max)
console.log('depth min and max :', depth_min, depth_max)

// -- Reference voxel that cropped volume started slice with it
const refVoxel = [row_min, col_min, depth_min]
// -- Starting form refVoxel, size of bounding volume
const boundVolSizeArr = [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]

coords.dispose()

// -- Extract 3d object (e.g. brain)
const cropped_slices_3d = await slices_3d.slice(
[row_min, col_min, depth_min],
Expand Down Expand Up @@ -1114,42 +1151,8 @@ async function inferenceFullVolumePhase2(
// -- pipeline1_out.dispose()
}
console.log(' mask_3d shape : ', mask_3d.shape)
const coords = await tf.whereAsync(mask_3d)
// -- Get each voxel coords (x, y, z)
const [row_min, row_max, col_min, col_max, depth_min, depth_max] = await firstLastNonZero3D(mask_3d)
mask_3d.dispose()
const coordsArr = coords.arraySync()

let row_min = slice_height
let row_max = 0
let col_min = slice_width
let col_max = 0
let depth_min = num_of_slices
let depth_max = 0

for (let i = 0; i < coordsArr.length; i++) {
if (row_min > coordsArr[i][0]) {
row_min = coordsArr[i][0]
} else if (row_max < coordsArr[i][0]) {
row_max = coordsArr[i][0]
}

if (col_min > coordsArr[i][1]) {
col_min = coordsArr[i][1]
} else if (col_max < coordsArr[i][1]) {
col_max = coordsArr[i][1]
}

if (depth_min > coordsArr[i][2]) {
depth_min = coordsArr[i][2]
} else if (depth_max < coordsArr[i][2]) {
depth_max = coordsArr[i][2]
}
}

console.log('row min and max :', row_min, row_max)
console.log('col min and max :', col_min, col_max)
console.log('depth min and max :', depth_min, depth_max)

// -- Reference voxel that cropped volume started slice with it
const refVoxel = [row_min, col_min, depth_min]
console.log('refVoxel :', refVoxel)
Expand All @@ -1158,9 +1161,6 @@ async function inferenceFullVolumePhase2(
const boundVolSizeArr = [row_max - row_min + 1, col_max - col_min + 1, depth_max - depth_min + 1]

console.log('boundVolSizeArr :', boundVolSizeArr)

coords.dispose()

// -- Extract 3d object (e.g. brain)
const cropped_slices_3d = slices_3d.slice(
[row_min, col_min, depth_min],
Expand Down Expand Up @@ -2096,4 +2096,4 @@ self.addEventListener(
runInferenceWW(event.data.opts, event.data.modelEntry, event.data.niftiHeader, event.data.niftiImage)
},
false
)
)

0 comments on commit 3e61e69

Please sign in to comment.