Skip to content

Commit

Permalink
Fix DeiT
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Dec 6, 2023
1 parent 6bba1b4 commit 4128761
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/transformers/models/deit/modeling_tf_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,7 @@ def __init__(self, config: DeiTConfig, **kwargs) -> None:
filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, name="0"
)
self.pixel_shuffle = TFDeitPixelShuffle(config.encoder_stride, name="1")
self.config = config

def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = inputs
Expand All @@ -849,7 +850,7 @@ def build(self, input_shape=None):
self.built = True
if getattr(self, "conv2d", None) is not None:
with tf.name_scope(self.conv2d.name):
self.conv2d.build(None)
self.conv2d.build(self.config.hidden_size)
if getattr(self, "pixel_shuffle", None) is not None:
with tf.name_scope(self.pixel_shuffle.name):
self.pixel_shuffle.build(None)
Expand Down Expand Up @@ -999,6 +1000,7 @@ def __init__(self, config: DeiTConfig):
if config.num_labels > 0
else tf.keras.layers.Activation("linear", name="classifier")
)
self.config = config

@unpack_inputs
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
Expand Down Expand Up @@ -1084,7 +1086,7 @@ def build(self, input_shape=None):
self.deit.build(None)
if getattr(self, "classifier", None) is not None:
with tf.name_scope(self.classifier.name):
self.classifier.build(None)
self.classifier.build(self.config.hidden_size)


@add_start_docstrings(
Expand Down Expand Up @@ -1117,6 +1119,7 @@ def __init__(self, config: DeiTConfig) -> None:
if config.num_labels > 0
else tf.keras.layers.Activation("linear", name="distillation_classifier")
)
self.config = config

@unpack_inputs
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
Expand Down Expand Up @@ -1175,7 +1178,7 @@ def build(self, input_shape=None):
self.deit.build(None)
if getattr(self, "cls_classifier", None) is not None:
with tf.name_scope(self.cls_classifier.name):
self.cls_classifier.build(None)
self.cls_classifier.build(self.config.hidden_size)
if getattr(self, "distillation_classifier", None) is not None:
with tf.name_scope(self.distillation_classifier.name):
self.distillation_classifier.build(None)
self.distillation_classifier.build(self.config.hidden_size)

0 comments on commit 4128761

Please sign in to comment.