diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py
index 90717dd1d3ebd4..49bc75b5f0aa6b 100644
--- a/src/transformers/models/bit/modeling_bit.py
+++ b/src/transformers/models/bit/modeling_bit.py
@@ -883,8 +883,9 @@ def forward(
         hidden_states = outputs.hidden_states
 
         feature_maps = ()
-        for idx in self.out_indices:
-            feature_maps += (hidden_states[idx],)
+        for idx, stage in enumerate(self.stage_names):
+            if stage in self.out_features:
+                feature_maps += (hidden_states[idx],)
 
         if not return_dict:
             output = (feature_maps,)
diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py
index fe8e63d6834f03..a952e5d8165e15 100755
--- a/src/transformers/models/convnext/modeling_convnext.py
+++ b/src/transformers/models/convnext/modeling_convnext.py
@@ -535,12 +535,10 @@ def forward(
         hidden_states = outputs.hidden_states if return_dict else outputs[1]
 
         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            if idx == 0:
-                raise ValueError("The stem should not be returned as a feature map.")
-            hidden_state = self.hidden_states_norms[stage](hidden_states[idx])
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                feature_maps += (hidden_state,)
 
         if not return_dict:
             output = (feature_maps,)
diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py
index 8798fde7d60d70..8d166200d12253 100644
--- a/src/transformers/models/convnextv2/modeling_convnextv2.py
+++ b/src/transformers/models/convnextv2/modeling_convnextv2.py
@@ -558,12 +558,10 @@ def forward(
         hidden_states = outputs.hidden_states if return_dict else outputs[1]
 
         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            if idx == 0:
-                raise ValueError("The stem should not be returned as a feature map.")
-            hidden_state = self.hidden_states_norms[stage](hidden_states[idx])
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                feature_maps += (hidden_state,)
 
         if not return_dict:
             output = (feature_maps,)
diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py
index 1b1c78806115f4..aae79e0452a2d7 100644
--- a/src/transformers/models/dinat/modeling_dinat.py
+++ b/src/transformers/models/dinat/modeling_dinat.py
@@ -955,16 +955,15 @@ def forward(
         hidden_states = outputs.reshaped_hidden_states
 
         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            hidden_state = hidden_states[idx]
-            batch_size, num_channels, height, width = hidden_state.shape
-            hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
-            hidden_state = hidden_state.view(batch_size, height * width, num_channels)
-            hidden_state = self.hidden_states_norms[stage](hidden_state)
-            hidden_state = hidden_state.view(batch_size, height, width, num_channels)
-            hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                batch_size, num_channels, height, width = hidden_state.shape
+                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)
 
         if not return_dict:
             output = (feature_maps,)
diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py
index d82743d15c34da..66bac639f6731b 100644
--- a/src/transformers/models/dinov2/modeling_dinov2.py
+++ b/src/transformers/models/dinov2/modeling_dinov2.py
@@ -830,20 +830,19 @@ def forward(
         hidden_states = outputs.hidden_states if return_dict else outputs[1]
 
         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            hidden_state = hidden_states[idx]
-            if self.config.apply_layernorm:
-                hidden_state = self.layernorm(hidden_state)
-            if self.config.reshape_hidden_states:
-                hidden_state = hidden_state[:, 1:]
-                # this was actually a bug in the original implementation that we copied here,
-                # cause normally the order is height, width
-                batch_size, _, height, width = pixel_values.shape
-                patch_size = self.config.patch_size
-                hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
-                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                if self.config.apply_layernorm:
+                    hidden_state = self.layernorm(hidden_state)
+                if self.config.reshape_hidden_states:
+                    hidden_state = hidden_state[:, 1:]
+                    # this was actually a bug in the original implementation that we copied here,
+                    # cause normally the order is height, width
+                    batch_size, _, height, width = pixel_values.shape
+                    patch_size = self.config.patch_size
+                    hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
+                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)
 
         if not return_dict:
             if output_hidden_states:
diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py
index 9bba06cd51dbc7..b0033c855985e7 100644
--- a/src/transformers/models/focalnet/modeling_focalnet.py
+++ b/src/transformers/models/focalnet/modeling_focalnet.py
@@ -1018,9 +1018,9 @@ def forward(
         hidden_states = outputs.reshaped_hidden_states
 
         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            feature_maps += (hidden_states[idx],)
+        for idx, stage in enumerate(self.stage_names):
+            if stage in self.out_features:
+                feature_maps += (hidden_states[idx],)
 
         if not return_dict:
             output = (feature_maps,)
diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py
index 00fa5ce90cd6ae..b4714860e6bffb 100644
--- a/src/transformers/models/maskformer/modeling_maskformer_swin.py
+++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -874,19 +874,17 @@ def forward(
             pixel_values, output_hidden_states=True, output_attentions=output_attentions, return_dict=True
         )
 
-        hidden_states = outputs.hidden_states
+        # we skip the stem
+        hidden_states = outputs.hidden_states[1:]
 
         # we need to reshape the hidden states to their original spatial dimensions
         # spatial dimensions contains all the heights and widths of each stage, including after the embeddings
+        spatial_dimensions: Tuple[Tuple[int, int]] = outputs.hidden_states_spatial_dimensions
         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            # we skip the stem
-            if idx == 0:
-                raise ValueError("The stem is not supported.")
-            hidden_state = hidden_states[idx]
-            (height, width) = outputs.hidden_states_spatial_dimensions[idx - 1]
-            norm = self.hidden_states_norms[idx - 1]
+        for i, (hidden_state, stage, (height, width)) in enumerate(
+            zip(hidden_states, self.stage_names[1:], spatial_dimensions)
+        ):
+            norm = self.hidden_states_norms[i]
             # the last element corespond to the layer's last block output but before patch merging
             hidden_state_unpolled = hidden_state[-1]
             hidden_state_norm = norm(hidden_state_unpolled)
@@ -896,7 +894,8 @@ def forward(
             hidden_state_permuted = (
                 hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
             )
-            feature_maps += (hidden_state_permuted,)
+            if stage in self.out_features:
+                feature_maps += (hidden_state_permuted,)
 
         if not return_dict:
             output = (feature_maps,)
diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py
index fac1d0edf328f9..278ed3d4b6bea2 100644
--- a/src/transformers/models/nat/modeling_nat.py
+++ b/src/transformers/models/nat/modeling_nat.py
@@ -934,16 +934,16 @@ def forward(
         hidden_states = outputs.reshaped_hidden_states
 
        feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            hidden_state = hidden_states[idx]
-            batch_size, num_channels, height, width = hidden_state.shape
-            hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
-            hidden_state = hidden_state.view(batch_size, height * width, num_channels)
-            hidden_state = self.hidden_states_norms[stage](hidden_state)
-            hidden_state = hidden_state.view(batch_size, height, width, num_channels)
-            hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                # TODO can we simplify this?
+                batch_size, num_channels, height, width = hidden_state.shape
+                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)
 
         if not return_dict:
             output = (feature_maps,)
diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py
index be2f6d45019f24..df460d58f042b5 100644
--- a/src/transformers/models/resnet/modeling_resnet.py
+++ b/src/transformers/models/resnet/modeling_resnet.py
@@ -496,8 +496,9 @@ def forward(
         hidden_states = outputs.hidden_states
 
         feature_maps = ()
-        for idx in self.out_indices:
-            feature_maps += (hidden_states[idx],)
+        for idx, stage in enumerate(self.stage_names):
+            if stage in self.out_features:
+                feature_maps += (hidden_states[idx],)
 
         if not return_dict:
             output = (feature_maps,)
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index b8d556f38ec589..4fe4be5ac79a6d 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -1326,16 +1326,15 @@ def forward(
         hidden_states = outputs.reshaped_hidden_states
 
         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            hidden_state = hidden_states[idx]
-            batch_size, num_channels, height, width = hidden_state.shape
-            hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
-            hidden_state = hidden_state.view(batch_size, height * width, num_channels)
-            hidden_state = self.hidden_states_norms[stage](hidden_state)
-            hidden_state = hidden_state.view(batch_size, height, width, num_channels)
-            hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
-            feature_maps += (hidden_state,)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                batch_size, num_channels, height, width = hidden_state.shape
+                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)
 
         if not return_dict:
             output = (feature_maps,)
diff --git a/src/transformers/models/timm_backbone/modeling_timm_backbone.py b/src/transformers/models/timm_backbone/modeling_timm_backbone.py
index ecd3baf98d667f..0c6fe67b75731f 100644
--- a/src/transformers/models/timm_backbone/modeling_timm_backbone.py
+++ b/src/transformers/models/timm_backbone/modeling_timm_backbone.py
@@ -77,10 +77,10 @@ def __init__(self, config, **kwargs):
         if getattr(config, "freeze_batch_norm_2d", False):
             self.freeze_batch_norm_2d()
 
-        # We force the backbone to return all of the hidden states and then filter them in the forward pass.
-        # This is to match the behavior of the other backbones in the library.
-        all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
-        self._backbone.return_layers = all_layers
+        # These are used to control the output of the model when called. If output_hidden_states is True, then
+        # return_layers is modified to include all layers.
+        self._return_layers = self._backbone.return_layers
+        self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
         super()._init_backbone(config)
 
     @classmethod
@@ -136,10 +136,14 @@ def forward(
         if output_attentions:
             raise ValueError("Cannot output attentions for timm backbones at the moment")
 
-        hidden_states = self._backbone(pixel_values, **kwargs)
-        feature_maps = tuple(hidden_states[i] for i in self.out_indices)
-
-        if not output_hidden_states:
+        if output_hidden_states:
+            # We modify the return layers to include all the stages of the backbone
+            self._backbone.return_layers = self._all_layers
+            hidden_states = self._backbone(pixel_values, **kwargs)
+            self._backbone.return_layers = self._return_layers
+            feature_maps = tuple(hidden_states[i] for i in self.out_indices)
+        else:
+            feature_maps = self._backbone(pixel_values, **kwargs)
             hidden_states = None
 
         feature_maps = tuple(feature_maps)
diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py
index cebc766229813c..4015875f0c7e27 100644
--- a/src/transformers/models/vitdet/modeling_vitdet.py
+++ b/src/transformers/models/vitdet/modeling_vitdet.py
@@ -858,9 +858,9 @@ def forward(
         hidden_states = outputs.hidden_states if return_dict else outputs[1]
 
         feature_maps = ()
-        for stage in self.out_features:
-            idx = self.stage_names.index(stage)
-            feature_maps += (hidden_states[idx],)
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                feature_maps += (hidden_state,)
 
         if not return_dict:
             if output_hidden_states:
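Not part of the patch: below is a minimal, self-contained sketch of the selection pattern the backbones above converge on, iterating over all stage names in order and keeping only the hidden states whose stage was requested via out_features, instead of indexing with stage_names.index(stage). The function name select_feature_maps and the toy stage_names, out_features and tensor shapes are invented for illustration and do not exist in the library.

from typing import List, Tuple

import torch


def select_feature_maps(
    hidden_states: Tuple[torch.Tensor, ...],
    stage_names: List[str],
    out_features: List[str],
) -> Tuple[torch.Tensor, ...]:
    # Walk over all stages in order and keep only the hidden states whose
    # stage name was requested. Order of feature maps follows stage order,
    # regardless of the order of out_features.
    feature_maps = ()
    for stage, hidden_state in zip(stage_names, hidden_states):
        if stage in out_features:
            feature_maps += (hidden_state,)
    return feature_maps


# Toy example: a backbone with a stem and four stages, returning stages 2 and 4.
stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]
out_features = ["stage2", "stage4"]
hidden_states = tuple(torch.randn(1, 2**i, 8, 8) for i in range(len(stage_names)))
feature_maps = select_feature_maps(hidden_states, stage_names, out_features)
print([fm.shape for fm in feature_maps])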