Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align backbone stage selection with out_indices & out_features #27606

Merged
merged 10 commits into from
Dec 20, 2023

Conversation

amyeroberts
Copy link
Collaborator

@amyeroberts amyeroberts commented Nov 20, 2023

What does this PR do?

This PR adds a set of input verification for the out_features and out_indices arguments for backbones, making sure that any accepted values align with the returned model outputs.

More details

out_features and out_indices are used to specify which blocks' attentions are returned by the Backbone classes.

The following can currently be passed in out_features:

  • Out-of-order stages: ["stage5", "stage2", "stage4"]
  • Double stages: ["stage3", "stage3"]

However, this is will not be reflected in the returned feature maps on a forward pass e.g. here for ResNet. The feature maps are selected by iterating over the stage_names (ordered list of all stages in the backbone) and returning those that have their name in out_features and so are in stage-order and will only be selected once.

There is also a misalignment between the TimmBackbone and transformers backbones - as timm will automatically take the set of indices (removing duplicates) whereas transformers will keep them in the out_indices attribute.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Could you elaborate on the motivation behind the out of order or duplicated (seems to be a choice rather than a bug fix for me no?)

if stage in self.out_features:
feature_maps += (hidden_states[idx],)
for idx in self.out_indices:
feature_maps += (hidden_states[idx],)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we not want to check the index here as well?

hidden_state = self.hidden_states_norms[stage](hidden_state)
feature_maps += (hidden_state,)
for stage in self.out_features:
idx = self.stage_names.index(stage)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I heard that .index is only available starting python 3.6 unless the dict was always ordered

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stage_names is a list and we only support py >= 3.8 so I think it's safe :)

@amyeroberts
Copy link
Collaborator Author

Hey! Could you elaborate on the motivation behind the out of order or duplicated (seems to be a choice rather than a bug fix for me no?)

@ArthurZucker Sure! It's both - a choice and a current bug. The choice is whether we allow passing in different orders and duplicates and the bug is whether this is reflected. At the moment I can pass in duplicates, out-of-order etc. but it won't be reflected in the returned stages. Another option is for input verification where we raise an error if the user chooses out_features or out_indices which have these properties. I could implement that instead? It might be a bit more defensive

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright thanks for explaining! No need for the check down with this fix 🔥

@amyeroberts
Copy link
Collaborator Author

@ArthurZucker Sorry to flip-flop. I've thought a bit more and concluded that not allowing repetitions & different orders when setting out_features and out_indices would be better:

  • It's possible to enable more flexible arguments later in the future but not the other way around - this wouldn't be backward compatible
  • Adding checks is backwards compatible: new errors might be raised with existing inputs but these would start flagging unexpected behaviour
  • Having multiple or out-of-order arguments is something the user can handle on their side after receiving the outputs

I'm going to update the PR to add these checks + relevant tests instead

@amyeroberts amyeroberts changed the title Fix backbone forward stage selection Aligb backbone stage selection and out_indices/out_features verification Nov 22, 2023
@amyeroberts amyeroberts changed the title Aligb backbone stage selection and out_indices/out_features verification Align backbone stage selection and out_indices/out_features verification Nov 22, 2023
@amyeroberts amyeroberts changed the title Align backbone stage selection and out_indices/out_features verification Align backbone stage selection with out_indices & out_features Nov 22, 2023
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Let's document this somewhere to make this change visible. LGTM !

@amyeroberts amyeroberts force-pushed the fix-backbone-forward-bug branch from 389668c to f5b5db4 Compare November 28, 2023 13:59
@amyeroberts
Copy link
Collaborator Author

@ArthurZucker There's isn't any proper documentation for the backbones atm - this is being added in #27456. I've added notes about the restrictions in the docstrings

@amyeroberts amyeroberts force-pushed the fix-backbone-forward-bug branch from 889b42e to 5c194eb Compare December 12, 2023 20:05
@amyeroberts amyeroberts merged commit ee298a1 into huggingface:main Dec 20, 2023
18 checks passed
@amyeroberts amyeroberts deleted the fix-backbone-forward-bug branch December 20, 2023 18:33
staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
…ngface#27606)

* Iteratre over out_features instead of stage_names

* Update for all backbones

* Add tests

* Fix

* Align timm backbone behaviour with other backbones

* Fix tests

* Stricter checks on set out_features and out_indices

* Revert back stage selection logic

* Remove out-of-order logic

* Document restriction in docstrings
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants