
Improving nnUNet inference speed #2048

Merged 8 commits into MIC-DKFZ:master from dev2 on Apr 11, 2024
Conversation

ancestor-mithril (Contributor)

Main additions are (sketched below):

  • Using torch.backends.cudnn.benchmark = True (already enabled during training, but forgotten during inference).
  • Replacing torch.no_grad contexts with a single torch.inference_mode context.
  • Using in-place operations to compute and apply the Gaussian importance map.
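A minimal sketch of the three changes (the network, input patch, and Gaussian here are placeholders; the actual implementation lives in nnunetv2/inference/):

```python
import torch
import torch.nn as nn

# 1. Let cuDNN benchmark conv algorithms for the fixed sliding-window
#    patch size (already enabled for training, now also for inference).
torch.backends.cudnn.benchmark = True

# Placeholder network, input patch, and Gaussian importance map.
network = nn.Conv3d(1, 2, kernel_size=3, padding=1)
data = torch.randn(1, 1, 32, 32, 32)
gaussian = torch.ones(32, 32, 32)

# 2. inference_mode() replaces no_grad(): it disables autograd tracking
#    and additionally skips version counters on created tensors.
with torch.inference_mode():
    logits = network(data)
    # 3. Weight the prediction in place instead of allocating a new tensor.
    logits *= gaussian
```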

FabianIsensee self-assigned this on Mar 28, 2024
Resolved review comments (outdated) on nnunetv2/inference/predict_from_raw_data.py and nnunetv2/inference/sliding_window_prediction.py
* Added n_predictions back to replicate the previous behavior
* Set the Gaussian to 1 when Gaussian weighting is not used
* Set the lru_cache size back to 2 to prevent OOM from unintended usage (see the sketch below)
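For context, a simplified sketch of what caching the Gaussian importance map with a small lru_cache can look like (an illustration under assumed parameters, not the exact nnUNet function):

```python
from functools import lru_cache

import numpy as np
import torch
from scipy.ndimage import gaussian_filter


@lru_cache(maxsize=2)  # small cache: reuses the map across patches without hoarding memory
def compute_gaussian(tile_size: tuple, sigma_scale: float = 1 / 8) -> torch.Tensor:
    """Gaussian importance map for blending sliding-window predictions."""
    tmp = np.zeros(tile_size)
    center = tuple(s // 2 for s in tile_size)
    tmp[center] = 1  # unit impulse at the patch center
    gaussian = gaussian_filter(tmp, sigma=[s * sigma_scale for s in tile_size])
    gaussian /= gaussian.max()  # normalize so the peak weight is 1
    return torch.from_numpy(gaussian)
```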
FabianIsensee merged commit 3a2d870 into MIC-DKFZ:master on Apr 11, 2024
1 check passed
FabianIsensee (Member)

Thanks! I had to make two small changes to make sure everything works and yields the same results.
What would you say is the advantage of @torch.inference_mode() over with torch.no_grad()? Prediction speed seems to be the same.

ancestor-mithril deleted the dev2 branch on April 11, 2024
ancestor-mithril (Contributor, Author)

Tensors created in inference mode are slimmer: they carry no version counter, and requires_grad can no longer be set to True on them. More information here: https://stackoverflow.com/a/74197846/18441695.
In my tests, @torch.inference_mode() is slightly less than 1% faster than no_grad for nnUNet.
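A small demonstration of the difference (hypothetical snippet, PyTorch >= 1.9):

```python
import torch

x = torch.ones(3)

with torch.no_grad():
    y = x * 2
y.requires_grad_(True)       # allowed: y is an ordinary tensor

with torch.inference_mode():
    z = x * 2
print(z.is_inference())      # True: no version counter is kept
# z.requires_grad_(True)     # would raise RuntimeError: inference tensors
#                            # cannot re-enter autograd
```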
