Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed May 10, 2024
1 parent 4f5299c commit 329fa00
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions src/olmo_core/distributed/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,6 @@ def load(
metadata = metadata or self.get_metadata(dir, no_dist=no_dist)
safetensors_mfl = _safetensors_mfl or SafeTensorsMultiFileLoader()

def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: str, filename: str):
if len((shape_in_file := loader.get_shape(key))) != 1:
raise ValueError(f"Expected a 1D tensor at {key} in {filename}, found shape {shape_in_file}")

if (dtype := loader.get_dtype(key)) != tensor.dtype:
raise ValueError(
f"Data type mismatch between tensor to load ({dtype}) and to load into ({tensor.dtype})"
)

# Load each tensor from the slices in each file.
for key in state_dict.keys():
log.debug("Loading tensor '%s' from state dict...", key)
Expand All @@ -395,7 +386,15 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key:
continue # no overlap with data in file, so nothing to load

with safetensors_mfl.open(f"{dir}/{filename}") as loader:
validate_shard_in_file(tensor, loader, key, filename)
# Validate the shard in the file.
if len((shape_in_file := loader.get_shape(key))) != 1:
raise ValueError(
f"Expected a 1D tensor at {key} in {filename}, found shape {shape_in_file}"
)
if (dtype := loader.get_dtype(key)) != tensor.dtype:
raise ValueError(
f"Data type mismatch between tensor to load ({dtype}) and to load into ({tensor.dtype})"
)

if overlap == OverlapType.EQUAL:
flat_view.view.copy_(loader.get_flat_slice(key))
Expand Down Expand Up @@ -826,7 +825,7 @@ def compute_overlap_with(self, other: TensorShardSpec, full_shape: Tuple[int, ..
else:
return OverlapType.MIXED

return None
return OverlapType.MIXED


def _offsets_overlap(offsets: Tuple[int, int], other_offsets: Tuple[int, int]) -> bool:
Expand Down

0 comments on commit 329fa00

Please sign in to comment.