From 61f6d3932653b123f9acec232962e537f5f4dd07 Mon Sep 17 00:00:00 2001
From: e-petrachi
Date: Tue, 6 Aug 2024 15:10:59 +0200
Subject: [PATCH] Added compatibility with Apple Metal Performance Shaders (MPS) chips for accelerated inference on PyTorch

---
 cog_predict.py                   | 3 ++-
 inference_realesrgan_video.py    | 6 ++++--
 realesrgan/utils.py              | 8 +++++---
 tests/test_discriminator_arch.py | 5 +++++
 4 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/cog_predict.py b/cog_predict.py
index fa0f89dfd..e779bc80d 100644
--- a/cog_predict.py
+++ b/cog_predict.py
@@ -49,7 +49,8 @@ def setup(self):
             )
 
     def choose_model(self, scale, version, tile=0):
-        half = True if torch.cuda.is_available() else False
+        # enable fp16 inference when a GPU backend (Apple MPS or CUDA) is available
+        half = True if torch.backends.mps.is_available() or torch.cuda.is_available() else False
         if version == 'General - RealESRGANplus':
             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
             model_path = 'weights/RealESRGAN_x4plus.pth'
diff --git a/inference_realesrgan_video.py b/inference_realesrgan_video.py
index c3c4d1465..0118a701a 100644
--- a/inference_realesrgan_video.py
+++ b/inference_realesrgan_video.py
@@ -269,7 +269,8 @@ def inference_video(args, video_save_path, device=None, total_workers=1, worker_
         else:
             writer.write_frame(output)
 
-        torch.cuda.synchronize(device)
+        # wait for queued GPU work on the backend in use (CUDA if present, otherwise Apple MPS)
+        torch.cuda.synchronize(device) if torch.cuda.is_available() else torch.mps.synchronize()
         pbar.update(1)
 
     reader.close()
@@ -286,7 +287,8 @@ def run(args):
         os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {tmp_frames_folder}/frame%08d.png')
         args.input = tmp_frames_folder
 
-    num_gpus = torch.cuda.device_count()
+    # count devices on the backend in use (CUDA if present, otherwise Apple MPS)
+    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else torch.mps.device_count()
     num_process = num_gpus * args.num_process_per_gpu
     if num_process == 1:
         inference_video(args, video_save_path)
diff --git a/realesrgan/utils.py b/realesrgan/utils.py
index 67e5232d6..bdcf54abf 100644
--- a/realesrgan/utils.py
+++ b/realesrgan/utils.py
@@ -45,9 +45,11 @@ def __init__(self,
         self.half = half
 
         # initialize model
-        if gpu_id:
-            self.device = torch.device(
-                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+        if torch.backends.mps.is_available():
+            # prefer Apple MPS when it is available and no explicit device was passed in
+            self.device = torch.device('mps') if device is None else device
+        elif gpu_id:
+            self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
         else:
             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
 
diff --git a/tests/test_discriminator_arch.py b/tests/test_discriminator_arch.py
index c56a40c77..6891e29ff 100644
--- a/tests/test_discriminator_arch.py
+++ b/tests/test_discriminator_arch.py
@@ -17,3 +17,8 @@ def test_unetdiscriminatorsn():
         net.cuda()
         output = net(img.cuda())
         assert output.shape == (1, 1, 32, 32)
+    # without CUDA, run the same check on Apple MPS (Module/Tensor have no .mps(); use .to('mps'))
+    elif torch.backends.mps.is_available():
+        net.to('mps')
+        output = net(img.to('mps'))
+        assert output.shape == (1, 1, 32, 32)