diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index b4f75ca7..16315553 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -360,15 +360,6 @@ def load( metadata = metadata or self.get_metadata(dir, no_dist=no_dist) safetensors_mfl = _safetensors_mfl or SafeTensorsMultiFileLoader() - def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: str, filename: str): - if len((shape_in_file := loader.get_shape(key))) != 1: - raise ValueError(f"Expected a 1D tensor at {key} in {filename}, found shape {shape_in_file}") - - if (dtype := loader.get_dtype(key)) != tensor.dtype: - raise ValueError( - f"Data type mismatch between tensor to load ({dtype}) and to load into ({tensor.dtype})" - ) - # Load each tensor from the slices in each file. for key in state_dict.keys(): log.debug("Loading tensor '%s' from state dict...", key) @@ -395,7 +386,15 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: continue # no overlap with data in file, so nothing to load with safetensors_mfl.open(f"{dir}/{filename}") as loader: - validate_shard_in_file(tensor, loader, key, filename) + # Validate the shard in the file. + if len((shape_in_file := loader.get_shape(key))) != 1: + raise ValueError( + f"Expected a 1D tensor at {key} in {filename}, found shape {shape_in_file}" + ) + if (dtype := loader.get_dtype(key)) != tensor.dtype: + raise ValueError( + f"Data type mismatch between tensor to load ({dtype}) and to load into ({tensor.dtype})" + ) if overlap == OverlapType.EQUAL: flat_view.view.copy_(loader.get_flat_slice(key)) @@ -826,7 +825,7 @@ def compute_overlap_with(self, other: TensorShardSpec, full_shape: Tuple[int, .. else: return OverlapType.MIXED - return None + return OverlapType.MIXED def _offsets_overlap(offsets: Tuple[int, int], other_offsets: Tuple[int, int]) -> bool: