diff --git a/gcvit/layers/attention.py b/gcvit/layers/attention.py index 9f15deb..f7a7ba8 100644 --- a/gcvit/layers/attention.py +++ b/gcvit/layers/attention.py @@ -33,7 +33,7 @@ def build(self, input_shape): dim * self.qkv_size, use_bias=self.qkv_bias, name="qkv" ) self.relative_position_bias_table = self.add_weight( - "relative_position_bias_table", + name="relative_position_bias_table", shape=[ (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads, @@ -50,7 +50,7 @@ def build(self, input_shape): self.proj_dropout, name="proj_drop" ) self.softmax = tf.keras.layers.Activation("softmax", name="softmax") - self.relative_position_index = self.get_relative_position_index() + # self.relative_position_index = self.get_relative_position_index() super().build(input_shape) def get_relative_position_index(self): @@ -101,7 +101,7 @@ def call(self, inputs, **kwargs): attn = q @ tf.transpose(k, perm=[0, 1, 3, 2]) relative_position_bias = tf.gather( self.relative_position_bias_table, - tf.reshape(self.relative_position_index, shape=[-1]), + tf.reshape(self.get_relative_position_index(), shape=[-1]), ) relative_position_bias = tf.reshape( relative_position_bias, diff --git a/gcvit/layers/block.py b/gcvit/layers/block.py index 25f8a4a..dbee207 100644 --- a/gcvit/layers/block.py +++ b/gcvit/layers/block.py @@ -65,14 +65,14 @@ def build(self, input_shape): ) if self.layer_scale is not None: self.gamma1 = self.add_weight( - "gamma1", + name="gamma1", shape=[C], initializer=tf.keras.initializers.Constant(self.layer_scale), trainable=True, dtype=self.dtype, ) self.gamma2 = self.add_weight( - "gamma2", + name="gamma2", shape=[C], initializer=tf.keras.initializers.Constant(self.layer_scale), trainable=True, diff --git a/gcvit/layers/feature.py b/gcvit/layers/feature.py index 8091a8c..f1eab44 100644 --- a/gcvit/layers/feature.py +++ b/gcvit/layers/feature.py @@ -67,11 +67,11 @@ def build(self, input_shape): self.avg_pool = AdaptiveAveragePooling2D(1, name="avg_pool") self.fc = [ tf.keras.layers.Dense( - int(inp * self.expansion), use_bias=False, name="fc/0" + int(inp * self.expansion), use_bias=False, name="fc_0" ), - tf.keras.layers.Activation("gelu", name="fc/1"), - tf.keras.layers.Dense(self.oup, use_bias=False, name="fc/2"), - tf.keras.layers.Activation("sigmoid", name="fc/3"), + tf.keras.layers.Activation("gelu", name="fc_1"), + tf.keras.layers.Dense(self.oup, use_bias=False, name="fc_2"), + tf.keras.layers.Activation("sigmoid", name="fc_3"), ] super().build(input_shape) @@ -111,17 +111,17 @@ def build(self, input_shape): strides=1, padding="valid", use_bias=False, - name="conv/0", + name="conv_0", ), - tf.keras.layers.Activation("gelu", name="conv/1"), - SE(name="conv/2"), + tf.keras.layers.Activation("gelu", name="conv_1"), + SE(name="conv_2"), tf.keras.layers.Conv2D( dim, kernel_size=1, strides=1, padding="valid", use_bias=False, - name="conv/3", + name="conv_3", ), ] self.reduction = tf.keras.layers.Conv2D( @@ -179,17 +179,17 @@ def build(self, input_shape): strides=1, padding="valid", use_bias=False, - name="conv/0", + name="conv_0", ), - tf.keras.layers.Activation("gelu", name="conv/1"), - SE(name="conv/2"), + tf.keras.layers.Activation("gelu", name="conv_1"), + SE(name="conv_2"), tf.keras.layers.Conv2D( dim, kernel_size=1, strides=1, padding="valid", use_bias=False, - name="conv/3", + name="conv_3", ), ] if not self.keep_dim: @@ -237,7 +237,7 @@ def __init__(self, keep_dims=False, **kwargs): def build(self, input_shape): self.to_q_global = [ - FeatExtract(keep_dim, name=f"to_q_global/{i}") + FeatExtract(keep_dim, name=f"to_q_global_{i}") for i, keep_dim in enumerate(self.keep_dims) ] super().build(input_shape) diff --git a/gcvit/layers/level.py b/gcvit/layers/level.py index eeff7b2..abd3fbe 100644 --- a/gcvit/layers/level.py +++ b/gcvit/layers/level.py @@ -59,7 +59,7 @@ def build(self, input_shape): attn_drop=self.attn_drop, path_drop=path_drop[i], layer_scale=self.layer_scale, - name=f"blocks/{i}", + name=f"blocks_{i}", ) for i in range(self.depth) ] diff --git a/gcvit/models/gcvit.py b/gcvit/models/gcvit.py index 3ca72a9..cec382e 100644 --- a/gcvit/models/gcvit.py +++ b/gcvit/models/gcvit.py @@ -142,7 +142,7 @@ def __init__( path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query, - name=f"levels/{i}", + name=f"levels_{i}", ) self.levels.append(level) self.norm = tf.keras.layers.LayerNormalization( @@ -227,7 +227,7 @@ def get_config(self): "resize_query": self.resize_query, "global_pool": self.global_pool, "num_classes": self.num_classes, - "head_act": self.head_act + "head_act": self.head_act, } ) return config