Skip to content

Commit

Permalink
Generalize sharding spec for ShardedFlatTensor (#15)
Browse files Browse the repository at this point in the history
* update docstrings

* generalize checkpointer to accept multiple offsets per rank

* support multiple offsets per shard with `ShardedFlatTensor`

* fix

* fix shardedflatparameter test

* fix checkpoint test

* fix init device for sharded flat tensor

* fix

* more logging

* add option to unshard with threads

* fix debug logging

* check for nans earlier
  • Loading branch information
epwalsh authored May 7, 2024
1 parent 70ee999 commit ae42293
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 171 deletions.
240 changes: 135 additions & 105 deletions src/olmo_core/distributed/checkpoint.py

Large diffs are not rendered by default.

24 changes: 14 additions & 10 deletions src/olmo_core/distributed/fsdp/flat_param_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def shard_params(
if not params:
params_data = ShardedFlatTensor(torch.empty(0, device=device))
params_data.mark_as_sharded(
ShardingSpec(unsharded_shape=(0,), unsharded_flattened_offsets=tuple([(0, 0)] * world_size))
ShardingSpec(unsharded_shape=(0,), unsharded_flattened_offsets=tuple([((0, 0),)] * world_size))
)
return FlatParamHandle(process_group=process_group, device=device)

Expand Down Expand Up @@ -151,34 +151,34 @@ def shard_params(
flat_param_global_offsets = (numel_running_total, numel_running_total + param.numel())

# First we need to determine which ranks will have a slice of the data.
unsharded_flattened_offsets: List[Tuple[int, int]] = []
unsharded_flattened_offsets: List[Tuple[Tuple[int, int], ...]] = []
for rank in range(world_size):
rank_global_start = rank * padded_sharded_numel
rank_global_end = rank_global_start + padded_sharded_numel
if (rank_global_end <= flat_param_global_offsets[0]) or (
flat_param_global_offsets[1] <= rank_global_start
):
# No overlap with this rank.
unsharded_flattened_offsets.append((0, 0))
unsharded_flattened_offsets.append(((0, 0),))
elif (
rank_global_start <= flat_param_global_offsets[0]
and flat_param_global_offsets[1] <= rank_global_end
):
# Param is completely contained by this rank.
unsharded_flattened_offsets.append((0, param.numel()))
unsharded_flattened_offsets.append(((0, param.numel()),))
elif (
rank_global_start <= flat_param_global_offsets[0]
and rank_global_end < flat_param_global_offsets[1]
):
# Param starts in this rank and ends in a subsequent rank.
unsharded_flattened_offsets.append((0, rank_global_end - flat_param_global_offsets[0]))
unsharded_flattened_offsets.append(((0, rank_global_end - flat_param_global_offsets[0]),))
elif (
flat_param_global_offsets[0] < rank_global_start
and flat_param_global_offsets[1] <= rank_global_end
):
# Param starts in a previous rank and ends in this one.
unsharded_flattened_offsets.append(
(rank_global_start - flat_param_global_offsets[0], param.numel())
((rank_global_start - flat_param_global_offsets[0], param.numel()),)
)
elif (
flat_param_global_offsets[0] < rank_global_start
Expand All @@ -187,8 +187,10 @@ def shard_params(
# Param spans this rank and overflows into other ranks.
unsharded_flattened_offsets.append(
(
rank_global_start - flat_param_global_offsets[0],
rank_global_end - flat_param_global_offsets[0],
(
rank_global_start - flat_param_global_offsets[0],
rank_global_end - flat_param_global_offsets[0],
),
)
)

Expand All @@ -202,7 +204,9 @@ def shard_params(
else:
flat_param = ShardedFlatParameter(
param.data.flatten()[
unsharded_flattened_offsets[local_rank][0] : unsharded_flattened_offsets[local_rank][1]
unsharded_flattened_offsets[local_rank][0][0] : unsharded_flattened_offsets[
local_rank
][0][1]
].to(device)
)
else:
Expand All @@ -229,7 +233,7 @@ def shard_params(
unsharded_shape=(padded_unsharded_numel,),
unsharded_flattened_offsets=tuple(
[
(start_idx, end_idx)
((start_idx, end_idx),)
for start_idx, end_idx in zip(
range(0, padded_unsharded_numel, padded_sharded_numel),
range(padded_sharded_numel, padded_unsharded_numel + 1, padded_sharded_numel),
Expand Down
3 changes: 1 addition & 2 deletions src/olmo_core/distributed/fsdp/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,7 @@ def collect_children(module: nn.Module, prefix: str = "") -> Generator[Tuple[str

def _managed_named_parameters(self) -> Generator[Tuple[str, ShardedFlatParameter], None, None]:
"""
Returns a generator over all parameters managed by this FSDP instance. This is equivalent
to `self.module.named_parameters()` except that parameters within nested FSDP instances are omitted.
Returns a generator over all parameters directly managed by this FSDP instance.
"""
for handle in self.state.flat_param_handles:
for param_name, param in zip(handle.param_fqns, handle.params):
Expand Down
137 changes: 94 additions & 43 deletions src/olmo_core/distributed/tensors/sharded_flat_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,58 @@ class ShardingSpec:
The shape of the full unsharded (unflattened) parameter.
"""

unsharded_flattened_offsets: Tuple[Tuple[int, int], ...]
unsharded_flattened_offsets: Tuple[Tuple[Tuple[int, int], ...], ...]
"""
The ``(start_idx, end_idx)`` within the full unsharded flattened parameter that each local shard
within the process group corresponds to.
The offsets (``(start_idx, end_idx)``) within the full unsharded flattened parameter that each
local shard within the process group corresponds to.
This tuple is indexed by rank. For example, the ``(start_idx, end_idx)`` within the full unsharded flattened
parameter for the local shard of the current rank is given by ``unsharded_flattened_offsets[dist.get_rank(process_group)]``.
This tuple is indexed by rank within the process group.
For example, the offsets within the full unsharded flattened parameter for the
local shard of the current rank is given by ``unsharded_flattened_offsets[dist.get_rank(process_group)]``.
"""

def __post_init__(self):
numel_accounted_for = 0
for offsets in self.unsharded_flattened_offsets:
assert offsets[0] <= offsets[1]
numel_accounted_for += offsets[1] - offsets[0]
for rank_offsets in self.unsharded_flattened_offsets:
for start_idx, end_idx in rank_offsets:
assert start_idx <= end_idx
numel_accounted_for += end_idx - start_idx
if numel_accounted_for != self.unsharded_numel:
raise ValueError(f"invalid sharding spec {self}")

@property
def unsharded_numel(self) -> int:
"""
The number of elements in the full unsharded tensor.
"""
return reduce(lambda x, y: x * y, self.unsharded_shape, 1)

@property
def sharded_numels(self) -> Tuple[int, ...]:
return tuple((offsets[1] - offsets[0] for offsets in self.unsharded_flattened_offsets))
"""
The number of elements in each shard.
"""
return tuple(
(
sum(end_idx - start_idx for start_idx, end_idx in offsets)
for offsets in self.unsharded_flattened_offsets
)
)

@property
def unsharded_flattened_shape(self) -> Tuple[int, ...]:
"""
The shape of the unsharded flattened tensor.
"""
return (self.unsharded_numel,)


class ShardedFlatTensor(torch.Tensor):
"""
:class:`ShardedFlatTensor` represents a sharded tensor with the assumption that every shard is
a contiguous slice into the flattened unsharded tensor.
"""

SHARDED_FLAT_TENSOR_METADATA_NAME = "__sharded_metadata__"
SHARDED_FLAT_TENSOR_SHARDING_SPEC_KEY = "sharding_spec"
SHARDED_FLAT_TENSOR_PROCESS_GROUP_KEY = "process_group"
Expand Down Expand Up @@ -104,7 +125,7 @@ def _gather_data(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
local_flat_padded_tensor = F.pad(self.data.to(dtype or self.dtype), local_padding)

# Pad sharded tensors to the same size.
if not rank0_only or get_rank(group=self.process_group) == 0:
if not rank0_only or local_rank == 0:
flat_sharded_tensor_list = [
torch.empty(max_numel, device=self.device, dtype=dtype or self.dtype)
for _ in range(len(self.sharding_spec.sharded_numels) - 1)
Expand All @@ -123,21 +144,38 @@ def _gather_data(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
# Gather padded sharded tensors to rank 0.
dist.gather(local_flat_padded_tensor, gather_list=flat_sharded_tensor_list, group=self.process_group)

# Unpad, sort by starting offset, and concatenate.
if flat_sharded_tensor_list is not None:
flat_tensor = torch.cat(
[
flat_sharded_tensor_list[idx][: sharded_numels[idx]]
for idx in sorted(
range(len(sharded_numels)),
key=lambda idx: self.sharding_spec.unsharded_flattened_offsets[idx][0],
)
]
)
return flat_tensor.reshape(self.sharding_spec.unsharded_shape)
else:
if flat_sharded_tensor_list is None:
# rank0_only=True and this is not rank 0.
return torch.empty(0, dtype=dtype or self.dtype, device=self.device)

# Unpad and pull out contiguous sharded chunks from each ranks sharded flat tensor.
contiguous_flat_sharded_tensors = []
contiguous_offsets = []
for rank, rank_sharded_tensor in enumerate(flat_sharded_tensor_list):
rank_sharded_tensor = rank_sharded_tensor[: sharded_numels[rank]]
local_offset = 0
for start_idx, end_idx in self.sharding_spec.unsharded_flattened_offsets[rank]:
chunk_numel = end_idx - start_idx
contiguous_flat_sharded_tensors.append(
rank_sharded_tensor[local_offset : local_offset + chunk_numel]
)
contiguous_offsets.append((start_idx, end_idx))
local_offset += chunk_numel

# Now sort by starting offset and concatenate together.
flat_tensor = torch.cat(
[
contiguous_flat_sharded_tensors[idx]
for idx in sorted(
range(len(contiguous_offsets)),
key=lambda idx: contiguous_offsets[idx][0],
)
]
)

# Reshape and return.
return flat_tensor.reshape(self.sharding_spec.unsharded_shape)

@classmethod
def shard(
cls: Type[T],
Expand All @@ -157,6 +195,8 @@ def shard(
)

tensor_is_initialized = tensor.device != torch.device("meta")
if device is None and tensor_is_initialized:
device = tensor.device

if synchronize and tensor_is_initialized:
if device is not None:
Expand All @@ -172,19 +212,23 @@ def shard(
world_size = get_world_size(group=process_group)
shard_max_numel = math.ceil(tensor.numel() / world_size)
all_offsets = tuple(
(rank * shard_max_numel, min((rank + 1) * shard_max_numel, tensor.numel()))
((rank * shard_max_numel, min((rank + 1) * shard_max_numel, tensor.numel())),)
for rank in range(world_size)
)
sharding_spec = ShardingSpec(
unsharded_shape=tuple(tensor.shape), unsharded_flattened_offsets=all_offsets
)

offsets = sharding_spec.unsharded_flattened_offsets[get_rank(group=process_group)]

sharded_tensor = torch.empty(
sharding_spec.sharded_numels[get_rank(group=process_group)], device=device, dtype=tensor.dtype
)
if tensor_is_initialized:
sharded_tensor = tensor.flatten()[offsets[0] : offsets[1]].clone().to(device=device)
else:
sharded_tensor = torch.empty(offsets[1] - offsets[0], device=device, dtype=tensor.dtype)
flat_tensor = tensor.flatten()
start_offset = 0
for start_idx, end_idx in sharding_spec.unsharded_flattened_offsets[get_rank(group=process_group)]:
chunk_numel = end_idx - start_idx
sharded_tensor[start_offset : start_offset + chunk_numel].copy_(flat_tensor[start_idx:end_idx])
start_offset += chunk_numel

sharded_tensor = cls( # type: ignore
sharded_tensor, requires_grad=requires_grad if requires_grad is not None else tensor.requires_grad
Expand All @@ -196,8 +240,11 @@ def gather(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False)
"""
Gather the sharded flat parameter across a process group into a full unsharded parameter.
"""
unsharded_data = self._gather_data(dtype=dtype, rank0_only=rank0_only)
unsharded_data.requires_grad = self.requires_grad
if self.is_sharded:
unsharded_data = self._gather_data(dtype=dtype, rank0_only=rank0_only)
unsharded_data.requires_grad = self.requires_grad
else:
unsharded_data = self.data
return unsharded_data

def unshard_(
Expand All @@ -212,7 +259,9 @@ def unshard_(
If ``rank0_only=True``, non rank 0 processes will have an empty tensor in their data.
"""
if unsharded_data is None:
unsharded_data = self._gather_data(dtype=dtype, rank0_only=rank0_only)
unsharded_data = (
self.data if not self.is_sharded else self._gather_data(dtype=dtype, rank0_only=rank0_only)
)
elif not rank0_only or get_rank(self.process_group) == 0:
unsharded_data = unsharded_data.view(*self.unsharded_shape)
self._set_metadata(self.SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY, self.data)
Expand Down Expand Up @@ -289,10 +338,13 @@ def chunk_unsharded(self, tensor: torch.Tensor, pad: bool = False) -> List[torch
chunks = []
flat_tensor = tensor.flatten()
max_size = max(self.sharding_spec.sharded_numels)
for offsets in self.sharding_spec.unsharded_flattened_offsets:
chunk = flat_tensor[offsets[0] : offsets[1]]
for rank_offsets in self.sharding_spec.unsharded_flattened_offsets:
rank_chunks = []
for start_idx, end_idx in rank_offsets:
rank_chunks.append(flat_tensor[start_idx:end_idx])
chunk = rank_chunks[0] if len(rank_chunks) == 1 else torch.cat(rank_chunks)
if pad:
chunk = F.pad(chunk, (0, max_size - (offsets[1] - offsets[0])))
chunk = F.pad(chunk, (0, max_size - chunk.numel()))
chunks.append(chunk)
return chunks

Expand All @@ -302,8 +354,11 @@ def sharded_chunk(self, tensor: torch.Tensor) -> torch.Tensor:
"""
if tensor.shape != self.unsharded_shape:
raise ValueError(f"shape mismatched, expected {self.unsharded_shape}, got {tensor.shape}")
offset_start, offset_end = self.unsharded_flattened_offsets
return tensor.flatten()[offset_start:offset_end]
flat_tensor = tensor.flatten()
rank_chunks = []
for start_idx, end_idx in self.unsharded_flattened_offsets:
rank_chunks.append(flat_tensor[start_idx:end_idx])
return rank_chunks[0] if len(rank_chunks) == 1 else torch.cat(rank_chunks)

@property
def is_sharded(self) -> bool:
Expand Down Expand Up @@ -337,12 +392,8 @@ def process_group(self) -> Optional[dist.ProcessGroup]:
)

@property
def unsharded_flattened_offsets(self) -> Tuple[int, int]:
# mypy is really bad some times
offsets: Tuple[int, int] = self.sharding_spec.unsharded_flattened_offsets[ # type: ignore[assignment]
get_rank(group=self.process_group)
]
return offsets
def unsharded_flattened_offsets(self) -> Tuple[Tuple[int, int], ...]:
return self.sharding_spec.unsharded_flattened_offsets[get_rank(group=self.process_group)]

@property
def unsharded_numel(self) -> int:
Expand Down
6 changes: 3 additions & 3 deletions src/test/distributed/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ def save_and_load_checkpoint_with_different_sharding_spec(dir):
[
# save_tensor: |x x x|x x x
# load_tensor: |x x|x x x x
(((0, 3), (3, 6)), ((0, 2), (2, 6))),
((((0, 3),), ((3, 6),)), (((0, 2),), ((2, 6),))),
# save_tensor: |x x x|x x x
# load_tensor: |x x x x|x x
(((0, 3), (3, 6)), ((0, 4), (4, 6))),
((((0, 3),), ((3, 6),)), (((0, 4),), ((4, 6),))),
# save_tensor: |x x x x x x|
# load_tensor: |x x x x|x x
(((0, 6), (6, 6)), ((0, 4), (4, 6))),
((((0, 6),), ((6, 6),)), (((0, 4),), ((4, 6),))),
]
):
checkpointer = Checkpointer()
Expand Down
8 changes: 4 additions & 4 deletions src/test/distributed/tensors/sharded_flat_parameter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def shard_and_gather(init_device: torch.device):
unsharded_shape = (2, 3)

for unsharded_flattened_offsets in [
((0, 3), (3, 6)), # balanced sharding
((0, 2), (2, 6)), # unbalanced sharding
((2, 6), (0, 2)), # unordered, unbalanced sharding
((0, 6), (6, 6)), # some ranks empty
(((0, 3),), ((3, 6),)), # balanced sharding
(((0, 2),), ((2, 6),)), # unbalanced sharding
(((2, 6),), ((0, 2),)), # unordered, unbalanced sharding
(((0, 6),), ((6, 6),)), # some ranks empty
None, # let ShardedFlatParameter decide
]:
tensor = torch.rand(*unsharded_shape, device=init_device)
Expand Down
9 changes: 5 additions & 4 deletions src/test/distributed/tensors/sharded_flat_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ def shard_and_gather(init_device: torch.device):
unsharded_shape = (2, 3)

for unsharded_flattened_offsets in [
((0, 3), (3, 6)), # balanced sharding
((0, 2), (2, 6)), # unbalanced sharding
((2, 6), (0, 2)), # unordered, unbalanced sharding
((0, 6), (6, 6)), # some ranks empty
(((0, 3),), ((3, 6),)), # balanced sharding
(((0, 2),), ((2, 6),)), # unbalanced sharding
(((2, 6),), ((0, 2),)), # unordered, unbalanced sharding
(((0, 6),), ((6, 6),)), # some ranks empty
(((0, 2), (4, 6)), ((2, 4),)), # more than one chunk on a rank
None, # let ShardedFlatTensor decide
]:
tensor = torch.rand(*unsharded_shape, device=init_device)
Expand Down

0 comments on commit ae42293

Please sign in to comment.