diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index e05c029665..b72e0a2b38 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -350,7 +350,15 @@ def for_d_in_datasets(dataset_configs):
                         split=None,
                     )
                 else:
-                    ds = load_from_disk(config_dataset.path)
+                    try:
+                        ds = load_from_disk(config_dataset.path)
+                    except FileNotFoundError:
+                        ds = load_dataset(
+                            config_dataset.path,
+                            name=config_dataset.name,
+                            streaming=False,
+                            split=None,
+                        )
             elif local_path.is_file():
                 ds_type = get_ds_type(config_dataset)
 
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index a57b6d83e2..e87f19cc7f 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -371,44 +371,79 @@ def test_load_hub_with_revision_with_dpo(self):
     def test_load_local_hub_with_revision(self):
         """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
-            with tempfile.TemporaryDirectory() as tmp_dir2:
-                tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
-                tmp_ds_path.mkdir(parents=True, exist_ok=True)
-                snapshot_download(
-                    repo_id="mhenrichsen/alpaca_2k_test",
-                    repo_type="dataset",
-                    local_dir=tmp_ds_path,
-                    revision="d05c1cb",
-                )
-
-                prepared_path = Path(tmp_dir) / "prepared"
-                cfg = DictDefault(
-                    {
-                        "tokenizer_config": "huggyllama/llama-7b",
-                        "sequence_len": 1024,
-                        "datasets": [
-                            {
-                                "path": "mhenrichsen/alpaca_2k_test",
-                                "ds_type": "parquet",
-                                "type": "alpaca",
-                                "data_files": [
-                                    f"{tmp_ds_path}/alpaca_2000.parquet",
-                                ],
-                                "revision": "d05c1cb",
-                            },
-                        ],
-                    }
-                )
+            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
+            tmp_ds_path.mkdir(parents=True, exist_ok=True)
+            snapshot_download(
+                repo_id="mhenrichsen/alpaca_2k_test",
+                repo_type="dataset",
+                local_dir=tmp_ds_path,
+                revision="d05c1cb",
+            )
+
+            prepared_path = Path(tmp_dir) / "prepared"
+            cfg = DictDefault(
+                {
+                    "tokenizer_config": "huggyllama/llama-7b",
+                    "sequence_len": 1024,
+                    "datasets": [
+                        {
+                            "path": "mhenrichsen/alpaca_2k_test",
+                            "ds_type": "parquet",
+                            "type": "alpaca",
+                            "data_files": [
+                                f"{tmp_ds_path}/alpaca_2000.parquet",
+                            ],
+                            "revision": "d05c1cb",
+                        },
+                    ],
+                }
+            )
 
-                dataset, _ = load_tokenized_prepared_datasets(
-                    self.tokenizer, cfg, prepared_path
-                )
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
-                assert len(dataset) == 2000
-                assert "input_ids" in dataset.features
-                assert "attention_mask" in dataset.features
-                assert "labels" in dataset.features
-                shutil.rmtree(tmp_ds_path)
+            assert len(dataset) == 2000
+            assert "input_ids" in dataset.features
+            assert "attention_mask" in dataset.features
+            assert "labels" in dataset.features
+            shutil.rmtree(tmp_ds_path)
+
+    def test_loading_local_dataset_folder(self):
+        """Verify that a dataset downloaded to a local folder can be loaded"""
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
+            tmp_ds_path.mkdir(parents=True, exist_ok=True)
+            snapshot_download(
+                repo_id="mhenrichsen/alpaca_2k_test",
+                repo_type="dataset",
+                local_dir=tmp_ds_path,
+            )
+
+            prepared_path = Path(tmp_dir) / "prepared"
+            cfg = DictDefault(
+                {
+                    "tokenizer_config": "huggyllama/llama-7b",
+                    "sequence_len": 1024,
+                    "datasets": [
+                        {
+                            "path": str(tmp_ds_path),
+                            "type": "alpaca",
+                        },
+                    ],
+                }
+            )
+
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
+
+            assert len(dataset) == 2000
+            assert "input_ids" in dataset.features
+            assert "attention_mask" in dataset.features
+            assert "labels" in dataset.features
+            shutil.rmtree(tmp_ds_path)
 
 
 if __name__ == "__main__":