[feat] don't download if already downloaded (#60)
tscholak authored Nov 21, 2024
1 parent d37557e commit 436d8d2
Showing 3 changed files with 33 additions and 31 deletions.
4 changes: 0 additions & 4 deletions docs/quick-start.md
@@ -332,8 +332,6 @@ Create a configuration file for the dataset preparation. Copy the following cont

tokenizer:
path: /mnt/inputs/SmolLM2-135M/tokenizer.json

remove_downloads: false
```

=== "Llama-3.2-1B"
@@ -351,8 +349,6 @@ Create a configuration file for the dataset preparation. Copy the following cont

tokenizer:
path: /mnt/inputs/Llama-3.2-1B/tokenizer.json

remove_downloads: false
```

and save it as `prepare-config.yaml` in your inputs folder.
5 changes: 0 additions & 5 deletions fast_llm/data/preparator/gpt_memmap/config.py
Expand Up @@ -133,11 +133,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
hint=FieldHint.optional,
valid=check_field(Assert.geq, 1),
)
remove_downloads: bool = Field(
default=False,
desc="Remove downloaded dataset after processing.",
hint=FieldHint.optional,
)
dataset: GPTHuggingfaceDatasetConfig = Field(
default_factory=GPTHuggingfaceDatasetConfig,
desc="Configuration for the dataset.",
55 changes: 33 additions & 22 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -1,5 +1,6 @@
import json
import multiprocessing
import pathlib
import shutil

import datasets
@@ -49,6 +50,17 @@ def _save_shard(self, args) -> dict:
}
return dataset_dict

def _load_dataset(self) -> datasets.Dataset:
dataset = datasets.load_dataset(
path=self._config.dataset.path,
name=self._config.dataset.config_name,
split=self._config.dataset.split,
num_proc=self._config.loading_workers,
trust_remote_code=self._config.dataset.trust_remote_code,
)
assert isinstance(dataset, datasets.Dataset)
return dataset

def run(self):
# Set transformers logging verbosity
transformers.logging.set_verbosity_error()
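The new `_load_dataset` helper passes `split` explicitly, which is why it can assert that it got back a concrete `datasets.Dataset`: when `split` is omitted, `load_dataset` returns a `DatasetDict` keyed by split name instead. A minimal sketch of that library behaviour, using a small public dataset purely for illustration:

```python
import datasets

# With `split` given, `load_dataset` returns a single `Dataset`,
# which is what `_load_dataset` asserts.
train = datasets.load_dataset("rotten_tomatoes", split="train")
assert isinstance(train, datasets.Dataset)

# Without `split`, it returns a `DatasetDict` keyed by split name,
# which would fail that assertion.
all_splits = datasets.load_dataset("rotten_tomatoes")
assert isinstance(all_splits, datasets.DatasetDict)
```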
@@ -86,25 +98,28 @@ def run(self):
# Prepare output directory
self._config.output_path.mkdir(parents=True, exist_ok=True)

# Download dataset if necessary on rank 0
download_path = self._config.output_path / "downloaded_dataset"
download_path_ok = download_path / "ok"
if self._config.distributed.rank == 0 and not download_path_ok.exists():
datasets.load_dataset(
path=self._config.dataset.path,
name=self._config.dataset.config_name,
split=self._config.dataset.split,
num_proc=self._config.loading_workers,
trust_remote_code=self._config.dataset.trust_remote_code,
).save_to_disk(download_path, num_proc=self._config.saving_workers)
download_path_ok.touch()

# Synchronize processes to wait for the download to finish
if self._config.distributed.world_size > 1:
torch.distributed.barrier()
if pathlib.Path(self._config.dataset.path).is_dir():
# Dataset is already downloaded, load from disk
dataset = self._load_dataset()
else:
# Dataset is not downloaded, download on rank 0
if self._config.distributed.rank == 0:
dataset = self._load_dataset()

# Synchronize processes to wait for the download to finish on rank 0
if self._config.distributed.world_size > 1:
torch.distributed.barrier()

# Load the downloaded dataset on remaining ranks
if self._config.distributed.rank != 0:
dataset = self._load_dataset()

# Load and shard the dataset on each rank
dataset = datasets.load_from_disk(download_path).shard(
# Synchronize processes to wait for the dataset to load on remaining ranks
if self._config.distributed.world_size > 1:
torch.distributed.barrier()

assert isinstance(dataset, datasets.Dataset)
dataset = dataset.shard(
num_shards=self._config.distributed.world_size,
index=self._config.distributed.rank,
)
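The flow above replaces the old `save_to_disk`/`load_from_disk` round trip: rank 0 calls `_load_dataset` first so the download lands in the (assumed shared) Hugging Face cache, the other ranks wait at a barrier and then load from that warm cache, and every rank finally takes its own shard. A condensed standalone sketch of this pattern, assuming an already-initialized process group and a placeholder `load_fn` standing in for `self._load_dataset`:

```python
import datasets
import torch.distributed as dist


def load_on_all_ranks(load_fn, rank: int, world_size: int) -> datasets.Dataset:
    dataset = None
    # Rank 0 loads first so that any download lands in the shared cache.
    if rank == 0:
        dataset = load_fn()
    if world_size > 1:
        dist.barrier()  # wait for rank 0 to finish downloading
    # The remaining ranks now load from the warm cache instead of the network.
    if rank != 0:
        dataset = load_fn()
    if world_size > 1:
        dist.barrier()  # keep ranks aligned before sharding
    # Each rank keeps only the shard it is responsible for.
    return dataset.shard(num_shards=world_size, index=rank)
```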
@@ -154,7 +169,3 @@ def run(self):
if self._config.distributed.world_size > 1:
torch.distributed.barrier()
torch.distributed.destroy_process_group()

# Clean up downloaded dataset
if self._config.remove_downloads and self._config.distributed.rank == 0:
shutil.rmtree(download_path, ignore_errors=True)
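With `remove_downloads` gone, the preparator no longer deletes anything itself; downloaded data now lives only in the Hugging Face `datasets` cache. If disk space matters, the cache can be cleared manually after preparation. A blunt, hedged sketch (it wipes the whole `datasets` cache, not just this dataset, and is not part of this commit):

```python
import os
import pathlib
import shutil

# Default location of the `datasets` cache; the HF_DATASETS_CACHE
# environment variable overrides it.
cache_dir = pathlib.Path(
    os.environ.get(
        "HF_DATASETS_CACHE",
        pathlib.Path.home() / ".cache" / "huggingface" / "datasets",
    )
)
# Manual counterpart of the removed `remove_downloads` cleanup.
shutil.rmtree(cache_dir, ignore_errors=True)
```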
