From dc2059472bc61c15a1587fd738de846090cd5f88 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Tue, 6 Aug 2024 12:29:47 -0400 Subject: [PATCH] Add transformations. --- nobrainer/layers/groupnorm.py | 2 ++ nobrainer/layers/padding.py | 1 + nobrainer/models/unetr.py | 5 +++++ nobrainer/transform.py | 3 +++ 4 files changed, 11 insertions(+) diff --git a/nobrainer/layers/groupnorm.py b/nobrainer/layers/groupnorm.py index aea7413a..cb19420e 100644 --- a/nobrainer/layers/groupnorm.py +++ b/nobrainer/layers/groupnorm.py @@ -26,6 +26,7 @@ from tensorflow.keras import constraints, initializers, regularizers from tensorflow.keras.layers import InputSpec, Layer from tensorflow.keras.utils import get_custom_objects +from tensorflow import function class GroupNormalization(Layer): @@ -143,6 +144,7 @@ def build(self, input_shape): self.beta = None self.built = True + @function def call(self, inputs, **kwargs): input_shape = K.int_shape(inputs) tensor_input_shape = K.shape(inputs) diff --git a/nobrainer/layers/padding.py b/nobrainer/layers/padding.py index 5994f65f..6d93c6e5 100644 --- a/nobrainer/layers/padding.py +++ b/nobrainer/layers/padding.py @@ -16,5 +16,6 @@ def __init__(self, padding, **kwds): self._paddings = [[0, 0], [0, 0], [0, 0], [0, 0], [self.padding, self.padding]] super(ZeroPadding3DChannels, self).__init__(**kwds) + @tf.function def call(self, x): return tf.pad(x, paddings=self._paddings, mode="CONSTANT", constant_values=0) diff --git a/nobrainer/models/unetr.py b/nobrainer/models/unetr.py index 3969199b..4972850d 100644 --- a/nobrainer/models/unetr.py +++ b/nobrainer/models/unetr.py @@ -19,6 +19,7 @@ def __init__(self, filters): output_padding=None, ) + @tf.function def call(self, inputs): return self.block(inputs) @@ -63,6 +64,7 @@ def __init__(self, filters, kernel_size=(3, 3, 3)): ] ) + @tf.function def call(self, inputs): return self.a(inputs) @@ -172,6 +174,7 @@ def __init__(self, cube_size, patch_size, embed_dim): ) # embedding - basically is adding numerical embedding to the layer along with an extra dim + @tf.function def call(self, inputs): patches = self.lyer(inputs) patches = tf.reshape( @@ -238,6 +241,7 @@ def __init__( for _ in range(num_layers) ] + @tf.function def call(self, inputs, training=True): extract_layers = [] x = inputs @@ -320,6 +324,7 @@ def __init__( [Conv3DBlock(64), Conv3DBlock(64), SingleConv3DBlock(output_dim, (1, 1, 1))] ) + @tf.function def call(self, x): z = self.transformer(x) z0, z3, z6, z9, z12 = x, z[0], z[1], z[2], z[3] diff --git a/nobrainer/transform.py b/nobrainer/transform.py index 2480c088..ed06c46e 100644 --- a/nobrainer/transform.py +++ b/nobrainer/transform.py @@ -211,12 +211,14 @@ def _get_coordinates(volume_shape): return tf.reshape(tf.stack(out, axis=3), shape=(-1, 3)) +@tf.function def _nearest_neighbor_interpolation(volume, coords): """Three-dimensional nearest neighbors interpolation.""" volume_f = _get_voxels(volume=volume, coords=tf.round(coords)) return tf.reshape(volume_f, volume.shape) +@tf.function def _trilinear_interpolation(volume, coords): """Trilinear interpolation. @@ -301,6 +303,7 @@ def _trilinear_interpolation(volume, coords): return tf.reshape(c, volume.shape) +@tf.function def _get_voxels(volume, coords): """Get voxels from volume at points. These voxels are in a flat tensor.""" x = tf.cast(volume, tf.float32)