Skip to content

Commit

Permalink
Add transformations.
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Aug 6, 2024
1 parent 231ec54 commit dc20594
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nobrainer/layers/groupnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tensorflow.keras import constraints, initializers, regularizers
from tensorflow.keras.layers import InputSpec, Layer
from tensorflow.keras.utils import get_custom_objects
from tensorflow import function


class GroupNormalization(Layer):
Expand Down Expand Up @@ -143,6 +144,7 @@ def build(self, input_shape):
self.beta = None
self.built = True

@function
def call(self, inputs, **kwargs):
input_shape = K.int_shape(inputs)
tensor_input_shape = K.shape(inputs)
Expand Down
1 change: 1 addition & 0 deletions nobrainer/layers/padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ def __init__(self, padding, **kwds):
self._paddings = [[0, 0], [0, 0], [0, 0], [0, 0], [self.padding, self.padding]]
super(ZeroPadding3DChannels, self).__init__(**kwds)

@tf.function
def call(self, x):
return tf.pad(x, paddings=self._paddings, mode="CONSTANT", constant_values=0)
5 changes: 5 additions & 0 deletions nobrainer/models/unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, filters):
output_padding=None,
)

@tf.function
def call(self, inputs):
return self.block(inputs)

Expand Down Expand Up @@ -63,6 +64,7 @@ def __init__(self, filters, kernel_size=(3, 3, 3)):
]
)

@tf.function
def call(self, inputs):
return self.a(inputs)

Expand Down Expand Up @@ -172,6 +174,7 @@ def __init__(self, cube_size, patch_size, embed_dim):
)
# embedding - basically is adding numerical embedding to the layer along with an extra dim

@tf.function
def call(self, inputs):
patches = self.lyer(inputs)
patches = tf.reshape(
Expand Down Expand Up @@ -238,6 +241,7 @@ def __init__(
for _ in range(num_layers)
]

@tf.function
def call(self, inputs, training=True):
extract_layers = []
x = inputs
Expand Down Expand Up @@ -320,6 +324,7 @@ def __init__(
[Conv3DBlock(64), Conv3DBlock(64), SingleConv3DBlock(output_dim, (1, 1, 1))]
)

@tf.function
def call(self, x):
z = self.transformer(x)
z0, z3, z6, z9, z12 = x, z[0], z[1], z[2], z[3]
Expand Down
3 changes: 3 additions & 0 deletions nobrainer/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,14 @@ def _get_coordinates(volume_shape):
return tf.reshape(tf.stack(out, axis=3), shape=(-1, 3))


@tf.function
def _nearest_neighbor_interpolation(volume, coords):
"""Three-dimensional nearest neighbors interpolation."""
volume_f = _get_voxels(volume=volume, coords=tf.round(coords))
return tf.reshape(volume_f, volume.shape)


@tf.function
def _trilinear_interpolation(volume, coords):
"""Trilinear interpolation.
Expand Down Expand Up @@ -301,6 +303,7 @@ def _trilinear_interpolation(volume, coords):
return tf.reshape(c, volume.shape)


@tf.function
def _get_voxels(volume, coords):
"""Get voxels from volume at points. These voxels are in a flat tensor."""
x = tf.cast(volume, tf.float32)
Expand Down

0 comments on commit dc20594

Please sign in to comment.