From d294fb839d845a596a0d5c63877463799b5570e5 Mon Sep 17 00:00:00 2001
From: Rex Cheng
Date: Mon, 23 Dec 2024 00:01:32 -0600
Subject: [PATCH] cuda takes priority

---
 demo.py        | 6 +++---
 gradio_demo.py | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/demo.py b/demo.py
index 8b9f5e4..9f073d6 100644
--- a/demo.py
+++ b/demo.py
@@ -63,10 +63,10 @@ def main():
     mask_away_clip: bool = args.mask_away_clip
 
     device = 'cpu'
-    if torch.backends.mps.is_available():
-        device = 'mps'
-    elif torch.cuda.is_available():
+    if torch.cuda.is_available():
         device = 'cuda'
+    elif torch.backends.mps.is_available():
+        device = 'mps'
     else:
         log.warning('CUDA/MPS are not available, running on CPU')
     dtype = torch.float32 if args.full_precision else torch.bfloat16
diff --git a/gradio_demo.py b/gradio_demo.py
index 8672bae..25e2747 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -19,10 +19,10 @@ log = logging.getLogger()
 
 device = 'cpu'
-if torch.backends.mps.is_available():
-    device = 'mps'
-elif torch.cuda.is_available():
+if torch.cuda.is_available():
     device = 'cuda'
+elif torch.backends.mps.is_available():
+    device = 'mps'
 else:
     log.warning('CUDA/MPS are not available, running on CPU')
 dtype = torch.bfloat16
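
Note: the change is identical in both files — the patch swaps the order of the two availability checks so that CUDA takes priority over MPS, with CPU remaining the fallback. For illustration only, the selection logic both scripts end up with could be factored into a single helper. This is a minimal sketch assuming only torch and the standard library; the pick_device name is hypothetical and not part of this patch:

import logging

import torch

log = logging.getLogger()


def pick_device() -> str:
    # Prefer CUDA; fall back to Apple-Silicon MPS, then CPU.
    if torch.cuda.is_available():
        return 'cuda'
    if torch.backends.mps.is_available():
        return 'mps'
    log.warning('CUDA/MPS are not available, running on CPU')
    return 'cpu'


device = pick_device()

A shared helper like this would also keep demo.py and gradio_demo.py from drifting apart, since today the same three-way fallback is duplicated in both files.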