Skip to content

Commit

Permalink
Move device check ahead of dist check
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Oct 21, 2024
1 parent 0965313 commit 7b802a2
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/open_clip_train/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ def init_distributed_device_so(
global_rank = 0
local_rank = 0
device_type, *device_idx = device.split(':', maxsplit=1)
is_avail, is_known = is_device_available(device_type)
if not is_known:
warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.")
elif not is_avail:
warnings.warn(f"Device {device} was not available, falling back to CPU.")
device_type = device = 'cpu'

if horovod:
import horovod.torch as hvd
Expand Down Expand Up @@ -172,13 +178,6 @@ def init_distributed_device_so(
global_rank = torch.distributed.get_rank()
distributed = True

is_avail, is_known = is_device_available(device_type)
if not is_known:
warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.")
elif not is_avail:
warnings.warn(f"Device {device} was not available, falling back to CPU.")
device_type = device = 'cpu'

if distributed and not no_set_device_rank and device_type not in ('cpu', 'mps'):
# Ignore manually specified device index in distributed mode and
# override with resolved local rank, fewer headaches in most setups.
Expand Down

0 comments on commit 7b802a2

Please sign in to comment.