Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to compile collective for PT > 2.3 #6674

Closed
wants to merge 22 commits into from
Closed
Changes from 2 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
236d13d
Allow to compile collective for PT > 2.3
nelyahu Oct 1, 2024
2ee4047
Merge branch 'master' into compile_collectives
nelyahu Oct 27, 2024
0ce079a
fix formatting
nelyahu Oct 28, 2024
71a2507
Merge branch 'master' into compile_collectives
loadams Oct 28, 2024
d8fd5c4
Merge branch 'master' into compile_collectives
loadams Oct 28, 2024
2cb1e6c
Merge branch 'master' into compile_collectives
loadams Oct 30, 2024
f3b00c7
Merge branch 'master' into compile_collectives
loadams Oct 30, 2024
ba500dd
Merge branch 'master' into compile_collectives
loadams Oct 31, 2024
beacedb
Merge branch 'master' into compile_collectives
loadams Oct 31, 2024
12124db
Merge branch 'master' into compile_collectives
loadams Nov 5, 2024
615dc47
Merge branch 'master' into compile_collectives
loadams Nov 6, 2024
a91f78a
Merge branch 'master' into compile_collectives
loadams Nov 6, 2024
54d9f1f
Merge branch 'master' into compile_collectives
loadams Nov 7, 2024
9a979a0
Merge branch 'master' into compile_collectives
loadams Nov 11, 2024
e6b4969
Merge branch 'master' into compile_collectives
loadams Nov 12, 2024
442e520
Merge branch 'master' into compile_collectives
loadams Nov 14, 2024
e4bc54f
Merge branch 'master' into compile_collectives
loadams Nov 14, 2024
733e0f2
Merge branch 'master' into compile_collectives
loadams Nov 15, 2024
fa57334
Merge branch 'master' into compile_collectives
loadams Nov 21, 2024
b63c9c8
Merge branch 'master' into compile_collectives
loadams Dec 11, 2024
685a71b
Merge branch 'master' into compile_collectives
loadams Dec 11, 2024
b568a26
Merge branch 'master' into compile_collectives
loadams Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions deepspeed/comm/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
DS_COMM_ALL_REDUCE_OFF = False
DS_COMM_REDUCE_OFF = False

def disable_compiler_collective(func):
nelyahu marked this conversation as resolved.
Show resolved Hide resolved
if required_torch_version(min_version=2.3):
return func
return compiler.disable(func)

def build_shm_op():
builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
Expand Down Expand Up @@ -114,7 +118,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())

@classmethod
@compiler.disable
@disable_compiler_collective
def get_all_gather_function(self):
if hasattr(torch.distributed, "all_gather_into_tensor"):
return torch.distributed.all_gather_into_tensor
Expand All @@ -123,7 +127,7 @@ def get_all_gather_function(self):
return None

@classmethod
@compiler.disable
@disable_compiler_collective
def get_reduce_scatter_function(self):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor
Expand All @@ -146,19 +150,19 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi'

@compiler.disable
@disable_compiler_collective
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

def inference_all_reduce(self, tensor, op, group=None):
if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1:
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
else:
return torch.ops.deepspeed.inference_all_reduce_(tensor)

@compiler.disable
@disable_compiler_collective
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
""" proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
Expand All @@ -169,15 +173,15 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_OFF:
if int(os.getenv('RANK', '0')) == 0:
utils.logger.warning("REDUCE is OFF")
return Noop()
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_SCATTER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -190,7 +194,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def broadcast(self, tensor, src, group=None, async_op=False):
if DS_COMM_BROADCAST_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -199,7 +203,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -208,15 +212,15 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
else:
return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
return self.all_gather_function(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -234,7 +238,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
@disable_compiler_collective
def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
""""""
assert len(output_tensors) == len(input_tensors), ""
Expand All @@ -258,7 +262,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
else:
reqs[-1].wait()

@compiler.disable
@disable_compiler_collective
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
Expand All @@ -272,7 +276,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
@disable_compiler_collective
def all_to_all_single(self,
output,
input,
Expand All @@ -287,49 +291,49 @@ def all_to_all_single(self,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def send(self, tensor, dst, group=None, tag=0):
return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def recv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def isend(self, tensor, dst, group=None, tag=0):
return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def irecv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
return torch.distributed.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
return torch.distributed.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

@compiler.disable
@disable_compiler_collective
def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
Expand Down
Loading