Added compatibility with Apple Metal Performance Shaders (MPS) chips #836

Open · wants to merge 1 commit into master
cog_predict.py: 3 changes (2 additions, 1 deletion)
@@ -49,7 +49,8 @@ def setup(self):
         )

     def choose_model(self, scale, version, tile=0):
-        half = True if torch.cuda.is_available() else False
+        # added Apple MPS support
+        half = True if torch.backends.mps.is_available() or torch.cuda.is_available() else False
         if version == 'General - RealESRGANplus':
             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
             model_path = 'weights/RealESRGAN_x4plus.pth'
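The hunk above gates fp16 inference on having either a CUDA or an MPS device. A minimal standalone sketch of that probe, using the backends-level check that ships with MPS-capable PyTorch builds (the helper name is illustrative, not part of this PR):

import torch

def use_half_precision() -> bool:
    # Hypothetical helper: enable fp16 only when an accelerator is present.
    # torch.backends.mps.is_available() is the portable MPS check; the
    # getattr guard keeps the probe safe on older CPU-only builds.
    if torch.cuda.is_available():
        return True
    mps_backend = getattr(torch.backends, 'mps', None)
    return bool(mps_backend is not None and mps_backend.is_available())

Keeping the CUDA branch first preserves the original behaviour on NVIDIA hosts.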
inference_realesrgan_video.py: 6 changes (4 additions, 2 deletions)
@@ -269,7 +269,8 @@ def inference_video(args, video_save_path, device=None, total_workers=1, worker_
         else:
             writer.write_frame(output)

-        torch.cuda.synchronize(device)
+        # added Apple MPS support
+        torch.cuda.synchronize(device) if torch.cuda.is_available() else torch.mps.synchronize()
         pbar.update(1)

     reader.close()
@@ -286,7 +287,8 @@ def run(args):
         os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {tmp_frames_folder}/frame%08d.png')
         args.input = tmp_frames_folder

-    num_gpus = torch.cuda.device_count()
+    # added Apple MPS support
+    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else torch.mps.device_count()
     num_process = num_gpus * args.num_process_per_gpu
     if num_process == 1:
         inference_video(args, video_save_path)
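Both call sites above assume one of the two accelerators is present. A hedged sketch of a CPU-safe variant, relying on the fact that MPS exposes a single device (helper names are illustrative, not part of this PR; torch.mps.device_count() only exists in recent PyTorch releases):

import torch

def sync_accelerator(device=None):
    # Flush queued kernels on whichever accelerator is in use; no-op on CPU.
    if torch.cuda.is_available():
        torch.cuda.synchronize(device)
    elif torch.backends.mps.is_available():
        torch.mps.synchronize()

def accelerator_count():
    # CUDA may expose several GPUs; MPS (and the CPU fallback) count as one,
    # so num_process stays equal to args.num_process_per_gpu off NVIDIA hosts.
    return torch.cuda.device_count() if torch.cuda.is_available() else 1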
realesrgan/utils.py: 8 changes (5 additions, 3 deletions)
@@ -45,9 +45,11 @@ def __init__(self,
         self.half = half

         # initialize model
-        if gpu_id:
-            self.device = torch.device(
-                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+        if torch.backends.mps.is_available():
+            # added Apple MPS support
+            self.device = torch.device('mps') if device is None else device
+        elif gpu_id:
+            self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
         else:
             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

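As written, the MPS branch above takes priority even when a caller passes an explicit gpu_id. A sketch of an ordering that lets an explicit device or gpu_id win over auto-detection (a hypothetical helper, not part of this PR):

import torch

def resolve_device(device=None, gpu_id=None):
    # Hypothetical helper mirroring the device selection in RealESRGANer.__init__.
    if device is not None:
        return torch.device(device)  # caller's explicit choice wins
    if torch.cuda.is_available():
        return torch.device(f'cuda:{gpu_id}') if gpu_id is not None else torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')  # Apple silicon GPU
    return torch.device('cpu')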
tests/test_discriminator_arch.py: 5 changes (5 additions, 0 deletions)
@@ -17,3 +17,8 @@ def test_unetdiscriminatorsn():
         net.cuda()
         output = net(img.cuda())
         assert output.shape == (1, 1, 32, 32)
+    # added Apple MPS support
+    elif torch.backends.mps.is_available():
+        net.to('mps')
+        output = net(img.to('mps'))
+        assert output.shape == (1, 1, 32, 32)
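Tensors and modules have no .mps() method, so the branch above moves them with .to('mps'). A small sketch of a device picker that also covers hosts with neither accelerator (the helper name is illustrative, not part of this PR):

import torch

def pick_test_device() -> torch.device:
    # Choose the accelerator to exercise in the test; fall back to CPU.
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

# Usage sketch inside the test:
#     device = pick_test_device()
#     net = net.to(device)
#     output = net(img.to(device))
#     assert output.shape == (1, 1, 32, 32)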