diff --git a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py index 39b7178219..668f42a679 100644 --- a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py @@ -89,11 +89,10 @@ def __iter__(self) -> Iterable[dict[str, NDArray]]: buffer += iids while len(buffer) >= self.max_length: concat_sample = buffer[:self.max_length] - buffer = buffer[self.max_length: - ] if self.should_wrap else [] + buffer = buffer[self. + max_length:] if self.should_wrap else [] yield { - 'tokens': - np.asarray(concat_sample, dtype=np.int32), + 'tokens': np.asarray(concat_sample, dtype=np.int32), } first_chunk = False