diff --git a/gradio_demo.py b/gradio_demo.py index d2aa68d..9c01b0c 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -18,11 +18,11 @@ log = logging.getLogger() -device = "cpu" +device = 'cpu' if torch.backends.mps.is_available(): - device = "mps" + device = 'mps' elif torch.cuda.is_available(): - device = "cuda" + device = 'cuda' dtype = torch.bfloat16