Fix a regression with checkpointing optimizer state (#20)
* Fix a regression with checkpointing optimizer state

* remove slot

* clean up
epwalsh authored May 15, 2024
1 parent 7a2299a commit b2c3c09
Showing 3 changed files with 21 additions and 9 deletions.
9 changes: 8 additions & 1 deletion src/olmo_core/distributed/checkpoint.py
@@ -1326,9 +1326,16 @@ def _get_local_tensor_data(tensor: torch.Tensor) -> torch.Tensor:


 def _wrap_tensor_for_sharded_parameter(tensor: torch.Tensor, param: Optional[torch.Tensor]) -> torch.Tensor:
-    if isinstance(tensor, (ShardedFlatTensor, DTensor)):
+    if isinstance(tensor, DTensor):
         return tensor
 
+    # TODO: (fixme) when you call `torch.empty_like(x)` on a `ShardedFlatTensor`, `x`, you get
+    # a `ShardedFlatTensor` without the metadata. Since PyTorch optimizer's use `torch.empty_like()`
+    # on each param to initialize its state, we run into an issue unless we still call `ShardedFlatTensor.wrap()`
+    # below.
+    # if isinstance(tensor, ShardedFlatTensor):
+    #     return tensor
+
     if isinstance(param, ShardedFlatTensor):
         return param.wrap(tensor, requires_grad=False)
     elif isinstance(param, DTensor):
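For context, the TODO above describes the failure mode behind this regression: PyTorch optimizers initialize per-parameter state with `torch.empty_like()`, which keeps the tensor subclass but drops the instance-level metadata, so the checkpoint code must still call `ShardedFlatTensor.wrap()`. A minimal sketch of that behavior, using a toy subclass rather than `ShardedFlatTensor` itself (`MetaTensor` and `_metadata` are illustrative names only):

import torch

class MetaTensor(torch.Tensor):
    # Toy stand-in for ShardedFlatTensor: metadata lives on the instance, not in the data.
    pass

x = torch.zeros(4).as_subclass(MetaTensor)
x._metadata = {"sharding_spec": "..."}   # instance-level metadata

state = torch.empty_like(x)              # what an optimizer does to create its state tensors
print(type(state).__name__)              # MetaTensor -- the subclass is preserved
print(hasattr(state, "_metadata"))       # False -- the metadata is not carried over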
3 changes: 2 additions & 1 deletion src/olmo_core/distributed/tensors/sharded_flat_parameter.py
@@ -15,6 +15,7 @@ class ShardedFlatParameter(ShardedFlatTensor, nn.Parameter):
     A :class:`~torch.nn.parameter.Parameter` version of :class:`ShardedFlatTensor`.
     """
 
+    @staticmethod
     def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> ShardedFlatParameter:
         if data is not None and data.ndim != 1:
             raise ValueError(f"{cls.__name__} requires flat data! Got {data.shape}")
@@ -29,7 +30,7 @@ def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True
             setattr(
                 param,
                 cls.SHARDED_FLAT_TENSOR_METADATA_NAME,
-                getattr(data, cls.SHARDED_FLAT_TENSOR_METADATA_NAME).copy(),
+                getattr(data, cls.SHARDED_FLAT_TENSOR_METADATA_NAME, {}).copy(),
             )
         else:
             setattr(param, cls.SHARDED_FLAT_TENSOR_METADATA_NAME, {})
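The second hunk is the actual guard: with the `{}` default, constructing a `ShardedFlatParameter` from a source tensor that never had the metadata attribute (for example, one produced by `torch.empty_like()`) falls back to an empty dict instead of raising `AttributeError`. A hypothetical illustration of the pattern (the attribute name here is made up; the real one is whatever `SHARDED_FLAT_TENSOR_METADATA_NAME` resolves to):

METADATA_ATTR = "metadata"                      # hypothetical attribute name

class Bare:
    # stands in for a tensor that was never given sharding metadata
    pass

src = Bare()
meta = getattr(src, METADATA_ATTR, {}).copy()   # -> {} rather than AttributeError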
18 changes: 11 additions & 7 deletions src/olmo_core/distributed/tensors/sharded_flat_tensor.py
@@ -82,6 +82,7 @@ class ShardedFlatTensor(torch.Tensor):
     SHARDED_FLAT_TENSOR_PROCESS_GROUP_KEY = "process_group"
     SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY = "sharded_data"
 
+    @staticmethod
     def __new__(cls, data: torch.Tensor, requires_grad: bool = False) -> ShardedFlatTensor:
         if data.ndim != 1:
             raise ValueError(f"{cls.__name__} requires flat data! Got {data.shape}")
@@ -362,18 +363,21 @@ def sharded_chunk(self, tensor: torch.Tensor) -> torch.Tensor:

     @property
     def is_sharded(self) -> bool:
-        metadata = getattr(self, self.SHARDED_FLAT_TENSOR_METADATA_NAME)
-        return (
-            self.SHARDED_FLAT_TENSOR_SHARDING_SPEC_KEY in metadata
-            and self.SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY not in metadata
-        )
+        try:
+            metadata = getattr(self, self.SHARDED_FLAT_TENSOR_METADATA_NAME)
+            return (
+                self.SHARDED_FLAT_TENSOR_SHARDING_SPEC_KEY in metadata
+                and self.SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY not in metadata
+            )
+        except AttributeError:
+            return False
 
     @property
     def sharding_spec(self) -> ShardingSpec:
-        metadata = getattr(self, self.SHARDED_FLAT_TENSOR_METADATA_NAME)
         try:
+            metadata = getattr(self, self.SHARDED_FLAT_TENSOR_METADATA_NAME)
             return metadata[self.SHARDED_FLAT_TENSOR_SHARDING_SPEC_KEY]
-        except KeyError:
+        except (KeyError, AttributeError):
             raise ValueError(
                 f"{self.__class__.__name__} has not been marked as sharded yet, "
                 "did you forget to class '.mark_as_sharded()'?"
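Taken together, the intended effect (a sketch of the assumed behavior, inferred only from the hunks above and assuming the constructor shown above accepts plain flat data) is that a metadata-less `ShardedFlatTensor` instance now reports `is_sharded` as `False` instead of raising `AttributeError`, while `sharding_spec` fails with the explanatory `ValueError`:

import torch
from olmo_core.distributed.tensors.sharded_flat_tensor import ShardedFlatTensor

t = ShardedFlatTensor(torch.zeros(8))    # flat data, not yet marked as sharded
bare = torch.empty_like(t)               # subclass instance with no metadata attribute (per the TODO)

assert bare.is_sharded is False          # previously raised AttributeError
try:
    bare.sharding_spec
except ValueError as err:
    print(err)                           # "... has not been marked as sharded yet ..."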
