Revert back stage selection logic
amyeroberts committed Nov 28, 2023
1 parent 6ff59f0 commit 1db3de3
Showing 12 changed files with 82 additions and 84 deletions.
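For orientation before the per-file hunks: the revert swaps direct indexing via out_indices (and stage_names.index(stage) lookups) for an ordered walk over stage_names filtered by out_features. A minimal, self-contained sketch of the two patterns, using made-up stage names and dummy feature maps rather than anything from the library:

    # Made-up stages and dummy "feature maps"; both patterns pick the same stages.
    stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]
    out_features = ["stage2", "stage4"]
    out_indices = [2, 4]
    hidden_states = [f"features_from_{name}" for name in stage_names]

    # Pattern being removed: index directly into the hidden states.
    feature_maps_by_index = tuple(hidden_states[idx] for idx in out_indices)

    # Pattern being restored: walk the stages in order and keep the requested ones.
    feature_maps_by_name = ()
    for idx, stage in enumerate(stage_names):
        if stage in out_features:
            feature_maps_by_name += (hidden_states[idx],)

    assert feature_maps_by_index == feature_maps_by_name

Both give the same tuple when out_indices and out_features describe the same stages; the name-based walk is what the hunks below restore.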
5 changes: 3 additions & 2 deletions src/transformers/models/bit/modeling_bit.py
@@ -883,8 +883,9 @@ def forward(
         hidden_states = outputs.hidden_states

         feature_maps = ()
-        for idx in self.out_indices:
-            feature_maps += (hidden_states[idx],)
+        for idx, stage in enumerate(self.stage_names):
+            if stage in self.out_features:
+                feature_maps += (hidden_states[idx],)

         if not return_dict:
             output = (feature_maps,)
10 changes: 4 additions & 6 deletions src/transformers/models/convnext/modeling_convnext.py
@@ -535,12 +535,10 @@ def forward(
         hidden_states = outputs.hidden_states if return_dict else outputs[1]

         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            if idx == 0:
-                raise ValueError("The stem should not be returned as a feature map.")
-            hidden_state = self.hidden_states_norms[stage](hidden_states[idx])
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                feature_maps += (hidden_state,)

         if not return_dict:
             output = (feature_maps,)
10 changes: 4 additions & 6 deletions src/transformers/models/convnextv2/modeling_convnextv2.py
@@ -558,12 +558,10 @@ def forward(
         hidden_states = outputs.hidden_states if return_dict else outputs[1]

         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            if idx == 0:
-                raise ValueError("The stem should not be returned as a feature map.")
-            hidden_state = self.hidden_states_norms[stage](hidden_states[idx])
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                feature_maps += (hidden_state,)

         if not return_dict:
             output = (feature_maps,)
19 changes: 9 additions & 10 deletions src/transformers/models/dinat/modeling_dinat.py
@@ -955,16 +955,15 @@ def forward(
         hidden_states = outputs.reshaped_hidden_states

         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            hidden_state = hidden_states[idx]
-            batch_size, num_channels, height, width = hidden_state.shape
-            hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
-            hidden_state = hidden_state.view(batch_size, height * width, num_channels)
-            hidden_state = self.hidden_states_norms[stage](hidden_state)
-            hidden_state = hidden_state.view(batch_size, height, width, num_channels)
-            hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                batch_size, num_channels, height, width = hidden_state.shape
+                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)

         if not return_dict:
             output = (feature_maps,)
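As a side note, the permute/view sequence kept inside the new "if stage in self.out_features:" block is only a round trip so that LayerNorm can act on the channel dimension. A standalone sketch with made-up shapes (nothing below comes from a real checkpoint):

    import torch
    from torch import nn

    # Dummy (B, C, H, W) feature map and a LayerNorm over the channel dim.
    batch_size, num_channels, height, width = 2, 64, 7, 7
    hidden_state = torch.randn(batch_size, num_channels, height, width)
    norm = nn.LayerNorm(num_channels)

    hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()                # (B, H, W, C)
    hidden_state = hidden_state.view(batch_size, height * width, num_channels)  # (B, H*W, C)
    hidden_state = norm(hidden_state)                                            # normalize over C
    hidden_state = hidden_state.view(batch_size, height, width, num_channels)   # (B, H, W, C)
    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()                 # back to (B, C, H, W)

    assert hidden_state.shape == (batch_size, num_channels, height, width)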
27 changes: 13 additions & 14 deletions src/transformers/models/dinov2/modeling_dinov2.py
@@ -830,20 +830,19 @@ def forward(
         hidden_states = outputs.hidden_states if return_dict else outputs[1]

         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            hidden_state = hidden_states[idx]
-            if self.config.apply_layernorm:
-                hidden_state = self.layernorm(hidden_state)
-            if self.config.reshape_hidden_states:
-                hidden_state = hidden_state[:, 1:]
-                # this was actually a bug in the original implementation that we copied here,
-                # cause normally the order is height, width
-                batch_size, _, height, width = pixel_values.shape
-                patch_size = self.config.patch_size
-                hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
-                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                if self.config.apply_layernorm:
+                    hidden_state = self.layernorm(hidden_state)
+                if self.config.reshape_hidden_states:
+                    hidden_state = hidden_state[:, 1:]
+                    # this was actually a bug in the original implementation that we copied here,
+                    # cause normally the order is height, width
+                    batch_size, _, height, width = pixel_values.shape
+                    patch_size = self.config.patch_size
+                    hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
+                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)

         if not return_dict:
             if output_hidden_states:
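The Dinov2 hunk reshapes ViT patch tokens back onto an image-shaped grid after dropping the CLS token. A rough, self-contained illustration; height, width, patch_size and dim below are placeholders, not real config values:

    import torch

    batch_size, height, width, patch_size, dim = 2, 224, 224, 14, 384
    num_patches = (height // patch_size) * (width // patch_size)
    # One CLS token followed by the patch tokens, as a ViT encoder would emit.
    hidden_state = torch.randn(batch_size, 1 + num_patches, dim)

    hidden_state = hidden_state[:, 1:]  # drop the CLS token
    hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()  # (B, dim, H/ps, W/ps)

    assert hidden_state.shape == (batch_size, dim, height // patch_size, width // patch_size)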
6 changes: 3 additions & 3 deletions src/transformers/models/focalnet/modeling_focalnet.py
@@ -1018,9 +1018,9 @@ def forward(
         hidden_states = outputs.reshaped_hidden_states

         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            feature_maps += (hidden_states[idx],)
+        for idx, stage in enumerate(self.stage_names):
+            if stage in self.out_features:
+                feature_maps += (hidden_states[idx],)

         if not return_dict:
             output = (feature_maps,)
19 changes: 9 additions & 10 deletions src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -874,19 +874,17 @@ def forward(
             pixel_values, output_hidden_states=True, output_attentions=output_attentions, return_dict=True
         )

-        hidden_states = outputs.hidden_states
+        # we skip the stem
+        hidden_states = outputs.hidden_states[1:]

         # we need to reshape the hidden states to their original spatial dimensions
         # spatial dimensions contains all the heights and widths of each stage, including after the embeddings
+        spatial_dimensions: Tuple[Tuple[int, int]] = outputs.hidden_states_spatial_dimensions
         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            # we skip the stem
-            if idx == 0:
-                raise ValueError("The stem is not supported.")
-            hidden_state = hidden_states[idx]
-            (height, width) = outputs.hidden_states_spatial_dimensions[idx - 1]
-            norm = self.hidden_states_norms[idx - 1]
+        for i, (hidden_state, stage, (height, width)) in enumerate(
+            zip(hidden_states, self.stage_names[1:], spatial_dimensions)
+        ):
+            norm = self.hidden_states_norms[i]
             # the last element corespond to the layer's last block output but before patch merging
             hidden_state_unpolled = hidden_state[-1]
             hidden_state_norm = norm(hidden_state_unpolled)
@@ -896,7 +894,8 @@ def forward(
             hidden_state_permuted = (
                 hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
             )
-            feature_maps += (hidden_state_permuted,)
+            if stage in self.out_features:
+                feature_maps += (hidden_state_permuted,)

         if not return_dict:
             output = (feature_maps,)
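For reference, the "b (h w) d -> b d h w" reshape in the hunk above can be checked in isolation; the sizes in this sketch are illustrative only:

    import torch

    batch_size, height, width, hidden_size = 2, 8, 8, 96
    # Normalized tokens laid out as (B, H*W, D), as produced inside the loop above.
    hidden_state_norm = torch.randn(batch_size, height * width, hidden_size)

    # Move the feature dim next to the batch dim, then view back onto the grid.
    hidden_state_permuted = (
        hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
    )

    assert hidden_state_permuted.shape == (batch_size, hidden_size, height, width)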
20 changes: 10 additions & 10 deletions src/transformers/models/nat/modeling_nat.py
@@ -934,16 +934,16 @@ def forward(
         hidden_states = outputs.reshaped_hidden_states

         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            hidden_state = hidden_states[idx]
-            batch_size, num_channels, height, width = hidden_state.shape
-            hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
-            hidden_state = hidden_state.view(batch_size, height * width, num_channels)
-            hidden_state = self.hidden_states_norms[stage](hidden_state)
-            hidden_state = hidden_state.view(batch_size, height, width, num_channels)
-            hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                # TODO can we simplify this?
+                batch_size, num_channels, height, width = hidden_state.shape
+                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)

         if not return_dict:
             output = (feature_maps,)
5 changes: 3 additions & 2 deletions src/transformers/models/resnet/modeling_resnet.py
@@ -496,8 +496,9 @@ def forward(
         hidden_states = outputs.hidden_states

         feature_maps = ()
-        for idx in self.out_indices:
-            feature_maps += (hidden_states[idx],)
+        for idx, stage in enumerate(self.stage_names):
+            if stage in self.out_features:
+                feature_maps += (hidden_states[idx],)

         if not return_dict:
             output = (feature_maps,)
19 changes: 9 additions & 10 deletions src/transformers/models/swin/modeling_swin.py
@@ -1326,16 +1326,15 @@ def forward(
         hidden_states = outputs.reshaped_hidden_states

         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            hidden_state = hidden_states[idx]
-            batch_size, num_channels, height, width = hidden_state.shape
-            hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
-            hidden_state = hidden_state.view(batch_size, height * width, num_channels)
-            hidden_state = self.hidden_states_norms[stage](hidden_state)
-            hidden_state = hidden_state.view(batch_size, height, width, num_channels)
-            hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                batch_size, num_channels, height, width = hidden_state.shape
+                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)

         if not return_dict:
             output = (feature_maps,)
20 changes: 12 additions & 8 deletions src/transformers/models/timm_backbone/modeling_timm_backbone.py
@@ -77,10 +77,10 @@ def __init__(self, config, **kwargs):
         if getattr(config, "freeze_batch_norm_2d", False):
             self.freeze_batch_norm_2d()

-        # We force the backbone to return all of the hidden states and then filter them in the forward pass.
-        # This is to match the behavior of the other backbones in the library.
-        all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
-        self._backbone.return_layers = all_layers
+        # These are used to control the output of the model when called. If output_hidden_states is True, then
+        # return_layers is modified to include all layers.
+        self._return_layers = self._backbone.return_layers
+        self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
         super()._init_backbone(config)

     @classmethod
@@ -136,10 +136,14 @@ def forward(
         if output_attentions:
             raise ValueError("Cannot output attentions for timm backbones at the moment")

-        hidden_states = self._backbone(pixel_values, **kwargs)
-        feature_maps = tuple(hidden_states[i] for i in self.out_indices)
-
-        if not output_hidden_states:
+        if output_hidden_states:
+            # We modify the return layers to include all the stages of the backbone
+            self._backbone.return_layers = self._all_layers
+            hidden_states = self._backbone(pixel_values, **kwargs)
+            self._backbone.return_layers = self._return_layers
+            feature_maps = tuple(hidden_states[i] for i in self.out_indices)
+        else:
+            feature_maps = self._backbone(pixel_values, **kwargs)
             hidden_states = None

         feature_maps = tuple(feature_maps)
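The timm change above replaces the old force-all-layers setup with a save/swap/restore of return_layers. A rough sketch of that control flow, with a stub class standing in for the real timm backbone (every name below is hypothetical):

    class StubBackbone:
        """Stand-in for a timm feature-extraction backbone: calling it returns one
        dummy feature per entry in return_layers."""

        def __init__(self):
            self.all_layers = {f"layer{i}": str(i) for i in range(4)}  # every stage
            self.return_layers = {"layer2": "2", "layer3": "3"}        # default subset

        def __call__(self, pixel_values):
            return [f"{name}({pixel_values})" for name in self.return_layers]

    def forward(backbone, pixel_values, output_hidden_states, out_indices=(2, 3)):
        if output_hidden_states:
            # Temporarily ask the backbone for every stage, then restore its config
            # so later calls keep returning only the default subset.
            default_layers = backbone.return_layers
            backbone.return_layers = backbone.all_layers
            hidden_states = backbone(pixel_values)
            backbone.return_layers = default_layers
            feature_maps = tuple(hidden_states[i] for i in out_indices)
        else:
            feature_maps = tuple(backbone(pixel_values))
            hidden_states = None
        return feature_maps, hidden_states

    print(forward(StubBackbone(), "img", output_hidden_states=True))
    print(forward(StubBackbone(), "img", output_hidden_states=False))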
6 changes: 3 additions & 3 deletions src/transformers/models/vitdet/modeling_vitdet.py
@@ -858,9 +858,9 @@ def forward(
         hidden_states = outputs.hidden_states if return_dict else outputs[1]

         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            feature_maps += (hidden_states[idx],)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                feature_maps += (hidden_state,)

         if not return_dict:
             if output_hidden_states:
