From 80f7c3df46b479d351a23492fcd7bc12551308fd Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Mon, 3 Jun 2024 09:29:15 +0200 Subject: [PATCH] disable torch.compile for mps, give clearer error messages, fix #2244 --- nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index d3803f391..b23847cb2 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -238,13 +238,24 @@ def initialize(self): def _do_i_compile(self): # new default: compile is enabled! + # compile does not work on mps + if self.device == torch.device('mps'): + if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): + self.print_to_log_file("INFO: torch.compile disabled because of unsupported mps device") + return False + # CPU compile crashes for 2D models. Not sure if we even want to support CPU compile!? Better disable if self.device == torch.device('cpu'): + if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): + self.print_to_log_file("INFO: torch.compile disabled because device is CPU") return False # default torch.compile doesn't work on windows because there are apparently no triton wheels for it # https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2 if os.name == 'nt': + if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): + self.print_to_log_file("INFO: torch.compile disabled because Windows is not natively supported. If " + "you know what you are doing, check https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2") return False if 'nnUNet_compile' not in os.environ.keys():