diff --git a/.dockerignore b/.dockerignore index 2022ee39..0ed5480a 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,4 +1,7 @@ +# Ignore everything by default * + +# Allow specific files and directories !setup.py !setup.cfg !Megatron-LM @@ -7,3 +10,7 @@ !tools !tests !pyproject.toml + +# Exclude Python cache directories and shared object files within included directories +**/__pycache__/ +**/*.so diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8629c06b..7b3accb2 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -57,12 +57,9 @@ jobs: ghcr.io/servicenow/fast-llm tags: | type=schedule - type=ref,event=branch - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=semver,pattern={{major}} + type=pep440,pattern={{version}} type=sha - type=raw,value=latest,enabled={{github.ref == 'refs/heads/main'}} + type=raw,value=latest,enable={{is_default_branch}} - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -78,7 +75,6 @@ jobs: uses: docker/build-push-action@v6 with: context: . - # push: ${{ github.event_name != 'pull_request' }} push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} diff --git a/Dockerfile b/Dockerfile index 9c3ecf49..b3b3cf13 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,34 +1,39 @@ # syntax=docker/dockerfile:1.7-labs FROM nvcr.io/nvidia/pytorch:24.07-py3 -# Install git-lfs for Huggingface hub interaction and sudo for system adjustments +# Install dependencies. RUN apt-get update \ - && apt-get install --no-install-recommends -y git-lfs sudo util-linux \ + && apt-get install --no-install-recommends -y acl git-lfs \ && rm -rf /var/lib/apt/lists/* \ && git lfs install -# Add a user for Fast-LLM with sudo privileges for runtime adjustments -ARG FAST_LLM_USER_ID=1000 -RUN useradd -m -u $FAST_LLM_USER_ID -s /bin/bash fast_llm \ - && echo 'fast_llm ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers - -USER fast_llm +# Set the working directory. WORKDIR /app +# Set the permission to 777 for all files and directories in `/app`, `/home` and python install directories: +# 1. Create the directories explicitly, because Docker applies the wrong permissions when they are created implicitly. +# 2. For the rest, set the default ACL to 777 for all users. +RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/tools \ + && setfacl -m d:u::rwx,d:g::rwx,d:o::rwx,u::rwx,g::rwx,o::rwx \ + /app \ + /home \ + /usr \ + /usr/local \ + /usr/local/bin \ + /usr/local/lib \ + /usr/local/lib/python3.10 \ + /usr/local/lib/python3.10/dist-packages \ + /usr/local/lib/python3.10/dist-packages/__pycache__ -# Environment settings for Python and PATH -ENV PYTHONPATH=/app:/app/Megatron-LM \ - PATH=$PATH:/home/fast_llm/.local/bin/ - -# Copy the dependency files and install dependencies -COPY --chown=fast_llm setup.py setup.cfg pyproject.toml ./ -COPY --chown=fast_llm ./fast_llm/csrc/ fast_llm/csrc/ -RUN PIP_NO_INPUT=1 pip3 install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,DEV]" +# Copy dependency files with universal write permissions for all users. +COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ +COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ -# Copy the rest of the code -COPY --chown=fast_llm ./Megatron-LM Megatron-LM -COPY --chown=fast_llm ./examples examples -COPY --chown=fast_llm ./tests tests -COPY --chown=fast_llm ./tools tools +# Install the Python dependencies.
+RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,DEV]" -# Copy the main source code for Fast-LLM -COPY --exclude=./fast_llm/csrc/ --chown=fast_llm ./fast_llm/ fast_llm/ +# Copy the remaining source code with universal write permissions. +COPY --chmod=777 ./Megatron-LM Megatron-LM +COPY --chmod=777 ./examples examples +COPY --chmod=777 ./tests tests +COPY --chmod=777 ./tools tools +COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/ diff --git a/fast_llm/config.py b/fast_llm/config.py index 66d9c530..a7b237c0 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -301,7 +301,10 @@ def __setattr__(self, key, value): # Allow setting the exact same object to facilitate setup of cross-dependencies. # Ex. allow re-setting cross-dependencies of already validated sub-configs. return - raise RuntimeError() + raise RuntimeError( + f"Cannot set attribute `{key}`" + f" in configuration class `{get_type_name(type(self))}` after validation." + ) super().__setattr__(key, value) def __delattr__(self, key): @@ -309,7 +312,10 @@ def __delattr__(self, key): Make the class read-only after validation. """ if getattr(self, "_validated", False): - raise RuntimeError() + raise RuntimeError( + f"Cannot delete attribute `{key}`" + f" in configuration class `{get_type_name(type(self))}` after validation." + ) super().__delattr__(key) def validate(self, *, _is_validating=False): diff --git a/fast_llm/data/auto.py b/fast_llm/data/auto.py new file mode 100644 index 00000000..1a5b7986 --- /dev/null +++ b/fast_llm/data/auto.py @@ -0,0 +1,12 @@ +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.utils import Registry + +dataset_preparator_registry = Registry( + "DatasetPreparator", + { + dataset_preparator.preparator_name: dataset_preparator + for dataset_preparator in [ + GPTMemmapDatasetPreparatorConfig, + ] + }, +) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 7487265a..59476eb4 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -107,7 +107,7 @@ def _validate(self): class TokenizerConfig(Config): """ Configuration for the tokenizer. - Currently, the tokenizer is only needed for FIM. + The tokenizer is needed for FIM and dataset preparation. """ format: str = Field( diff --git a/fast_llm/data/gpt/memmap.py b/fast_llm/data/gpt/memmap.py index b49bb9a5..a2a57271 100644 --- a/fast_llm/data/gpt/memmap.py +++ b/fast_llm/data/gpt/memmap.py @@ -4,6 +4,8 @@ import numpy as np from fast_llm.data.gpt.dataset import GPTIndexedDataset +from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div, padded_cumsum @@ -16,18 +18,6 @@ class GPTMemmapDataset(GPTIndexedDataset): See https://github.com/NVIDIA/Megatron-LM?tab=readme-ov-file#data-preprocessing for more details. 
""" - _DTYPES = { - 1: np.uint8, - 2: np.int8, - 3: np.int16, - 4: np.int32, - 5: np.int64, - 6: np.float32, - 7: np.float64, - 8: np.uint16, - } - _INDEX_HEADER = b"MMIDIDX\x00\x00" - def __init__(self, name: str, prefix: pathlib.Path | str): self._init(name, prefix) @@ -37,10 +27,10 @@ def _init(self, name: str, prefix: pathlib.Path | str): self._prefix = pathlib.Path(prefix) with self._prefix.with_suffix(".idx").open("rb") as stream: - Assert.eq(stream.read(9), self._INDEX_HEADER) + Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER) Assert.eq(struct.unpack(" type["DatasetPreparator"]: + raise NotImplementedError + + def _get_runnable(self, parsed: argparse.Namespace) -> typing.Callable[[], None]: + dataset_preparator = self.get_dataset_preparator_class()(config=self) + return dataset_preparator.run + + +class DatasetPreparator(abc.ABC): + _config: DatasetPreparatorConfig + config_class: typing.ClassVar[type[DatasetPreparatorConfig]] = DatasetPreparatorConfig + + def __init__(self, config: DatasetPreparatorConfig) -> None: + Assert.custom(isinstance, config, self.config_class) + config.validate() + self._config = config + + @abc.abstractmethod + def run(self) -> None: + raise NotImplementedError diff --git a/fast_llm/data/preparator/gpt_memmap/__init__.py b/fast_llm/data/preparator/gpt_memmap/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py new file mode 100644 index 00000000..9188a14c --- /dev/null +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -0,0 +1,162 @@ +import os +import pathlib +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.data.config import TokenizerConfig +from fast_llm.data.preparator.config import DatasetPreparatorConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert + +MEMMAP_DTYPES = { + 1: DataType.uint8, + 2: DataType.int8, + 3: DataType.int16, + 4: DataType.int32, + 5: DataType.int64, + 6: DataType.float32, + 7: DataType.float64, + 8: DataType.uint16, +} +MEMMAP_DTYPES_INV = {y: x for x, y in MEMMAP_DTYPES.items()} +MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" + + +@config_class +class GPTHuggingfaceDatasetConfig(Config): + path: str = Field( + default=None, + desc="Name or path of the dataset.", + hint=FieldHint.core, + ) + config_name: None | str = Field( + default=None, + desc="Specific configuration name for the dataset.", + hint=FieldHint.optional, + ) + split: str = Field( + default="train", + desc="Split of the dataset to use.", + hint=FieldHint.optional, + ) + field: str = Field( + default="text", + desc="Field of the dataset to use.", + hint=FieldHint.optional, + ) + data_type: DataType | None = Field( + default=None, + desc="Data type of the dataset field." + " If not provided, it will be inferred based on the tokenizer vocabulary size.", + hint=FieldHint.optional, + ) + trust_remote_code: bool = Field( + default=False, + desc="Trust remote code when downloading the dataset.", + hint=FieldHint.optional, + ) + disable_disk_space_check: bool = Field( + default=False, + desc="Disable disk space check. 
Useful for environments where disk space is not accurately reported.", + hint=FieldHint.optional, + ) + + +@config_class +class DatasetPreparatorDistributedConfig(Config): + # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig + + default_world_size: typing.ClassVar[int] = int(os.environ.get("WORLD_SIZE", 1)) + default_rank: typing.ClassVar[int] = int(os.environ.get("RANK", 0)) + world_size: int = Field( + default=None, + desc="Size of the world group. Typically provided by torchrun or equivalent through the `WORLD_SIZE` environment variable.", + hint=FieldHint.expert, + valid=check_field(Assert.gt, 0), + ) + rank: int = Field( + default=None, + desc="Rank of the local process. Typically provided by torchrun or equivalent through the `RANK` environment variable.", + hint=FieldHint.expert, + valid=check_field(Assert.geq, 0), + ) + backend: str = Field( + default="gloo", + desc="Distributed backend to use.", + hint=FieldHint.optional, + ) + + def _validate(self): + if self.world_size is None: + self.world_size = self.default_world_size + if self.rank is None: + self.rank = self.default_rank + super()._validate() + Assert.in_range(self.rank, 0, self.world_size) + + +@config_class() +class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): + preparator_name: typing.ClassVar[str] = "gpt_memmap" + + output_path: pathlib.Path = Field( + default=None, + desc="Output directory for the processed dataset.", + hint=FieldHint.core, + ) + distributed: DatasetPreparatorDistributedConfig = Field( + default_factory=DatasetPreparatorDistributedConfig, + desc="Configuration for distributed processing.", + hint=FieldHint.feature, + ) + tokens_per_shard: int = Field( + default=10**9, + desc="Approximate number of tokens per shard.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 10**5), + ) + loading_workers: int = Field( + default=1, + desc="Number of workers in load_dataset() call.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 1), + ) + tokenize_workers: int = Field( + default=1, + desc="Number of workers for tokenization.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 1), + ) + saving_workers: int = Field( + default=1, + desc="Number of processes for saving the data.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 1), + ) + remove_downloads: bool = Field( + default=False, + desc="Remove downloaded dataset after processing.", + hint=FieldHint.optional, + ) + dataset: GPTHuggingfaceDatasetConfig = Field( + default_factory=GPTHuggingfaceDatasetConfig, + desc="Configuration for the dataset.", + hint=FieldHint.feature, + ) + tokenizer: TokenizerConfig = Field( + default_factory=TokenizerConfig, + desc="Configuration for the tokenizer.", + hint=FieldHint.feature, + ) + + def _validate(self): + assert self.tokenizer.path is not None + if self.dataset.data_type is not None: + Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) + super()._validate() + + @classmethod + def get_dataset_preparator_class(cls): + from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator + + return GPTMemmapDatasetPreparator diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py new file mode 100644 index 00000000..c51bd4a7 --- /dev/null +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -0,0 +1,162 @@ +import json +import multiprocessing +import shutil + +import datasets +import numpy as np +import torch.distributed +import tqdm +import 
transformers + +from fast_llm.data.gpt.memmap import GPTMemmapDataset +from fast_llm.data.preparator.config import DatasetPreparator +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.tokenizer import Tokenizer +from fast_llm.engine.config_utils.data_type import DataType + + +class GPTMemmapDatasetPreparator(DatasetPreparator): + _config: GPTMemmapDatasetPreparatorConfig + config_class = GPTMemmapDatasetPreparatorConfig + + _tokenizer: Tokenizer + _data_type: DataType + + def _tokenize_batch(self, batch): + input_ids = [ + np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) + for text in batch[self._config.dataset.field] + ] + num_tokens = [len(x) for x in input_ids] + return { + "input_ids": input_ids, + "num_tokens": num_tokens, + } + + def _save_shard(self, args) -> dict: + + shard_idx, shard_dataset = args + prefix = f"shard_{self._config.distributed.rank}_{shard_idx}" + shard_output_path = self._config.output_path / prefix + documents = [ + np.array(item["input_ids"], dtype=self._data_type.numpy) + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs") + ] + GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=documents) + dataset_dict = { + "prefix": prefix, + "num_documents": len(documents), + "num_tokens": sum(len(doc) for doc in documents), + } + return dataset_dict + + def run(self): + + # Set transformers logging verbosity + transformers.logging.set_verbosity_error() + + # Disable disk space check if requested + if self._config.dataset.disable_disk_space_check: + datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True + + # Load tokenizer + self._tokenizer = Tokenizer(config=self._config.tokenizer) + + # Set data type if not provided + if self._config.dataset.data_type is None: + # Decide the datatype based on the tokenizer vocabulary size + vocab_size = self._tokenizer.vocab_size + if vocab_size <= np.iinfo(np.int16).max: + self._data_type = DataType.int16 + # elif vocab_size <= np.iinfo(np.uint16).max: + # self._data_type = DataType.uint16 # Not supported by Fast-LLM's DataType + elif vocab_size <= np.iinfo(np.int32).max: + self._data_type = DataType.int32 + else: + raise ValueError(f"Tokenizer vocabulary size {vocab_size} is too large. 
This is likely an error.") + else: + self._data_type = self._config.dataset.data_type + + # Initialize distributed processing + if self._config.distributed.world_size > 1: + torch.distributed.init_process_group( + backend=self._config.distributed.backend, + rank=self._config.distributed.rank, + world_size=self._config.distributed.world_size, + ) + + # Prepare output directory + self._config.output_path.mkdir(parents=True, exist_ok=True) + + # Download dataset if necessary on rank 0 + download_path = self._config.output_path / "downloaded_dataset" + download_path_ok = download_path / "ok" + if self._config.distributed.rank == 0 and not download_path_ok.exists(): + datasets.load_dataset( + path=self._config.dataset.path, + name=self._config.dataset.config_name, + split=self._config.dataset.split, + num_proc=self._config.loading_workers, + trust_remote_code=self._config.dataset.trust_remote_code, + ).save_to_disk(download_path, num_proc=self._config.saving_workers) + download_path_ok.touch() + + # Synchronize processes to wait for the download to finish + if self._config.distributed.world_size > 1: + torch.distributed.barrier() + + # Load and shard the dataset on each rank + dataset = datasets.load_from_disk(download_path).shard( + num_shards=self._config.distributed.world_size, + index=self._config.distributed.rank, + ) + if self._config.dataset.field not in dataset.column_names: + raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") + + # Tokenize the dataset in parallel + tokenized_dataset = dataset.map( + self._tokenize_batch, + batched=True, + num_proc=self._config.tokenize_workers, + desc="Tokenizing batches", + ) + + # Calculate total number of tokens + total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + + # Split dataset into shards based on number of tokens + num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) + shards = [ + (i, tokenized_dataset.shard(num_shards=num_shards, index=i)) + for i in tqdm.tqdm(range(num_shards), desc="Creating shards") + ] + + # Use multiprocessing to save each shard in parallel on all ranks + with multiprocessing.Pool(processes=self._config.saving_workers) as pool: + dataset_dicts = pool.map(self._save_shard, shards) + + # Gather dataset_dicts from all ranks to rank 0 + if self._config.distributed.world_size > 1: + if self._config.distributed.rank == 0: + all_dataset_dicts = [None] * self._config.distributed.world_size + torch.distributed.gather_object(dataset_dicts, all_dataset_dicts, dst=0) + dataset_dicts = [item for sublist in all_dataset_dicts for item in sublist] + else: + torch.distributed.gather_object(dataset_dicts, [], dst=0) + + # Create a metadata file on rank 0 + if self._config.distributed.rank == 0: + total_tokens = sum(dataset_dict["num_tokens"] for dataset_dict in dataset_dicts) + for dataset_dict in dataset_dicts: + dataset_dict["weight"] = float(dataset_dict["num_tokens"]) / float(total_tokens) + output_file = self._config.output_path / "fast_llm_dataset.json" + json.dump({"datasets": dataset_dicts}, output_file.open("w")) + + # Finalize distributed processing + if self._config.distributed.world_size > 1: + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + # Clean up downloaded dataset + if self._config.remove_downloads and self._config.distributed.rank == 0: + shutil.rmtree(download_path) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index d75aab7f..2061d6b6 100644 --- 
a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -6,7 +6,7 @@ class Tokenizer: """ - A Huggingface (transformers) tokenizer. + A wrapper around a Huggingface (transformers) tokenizer. """ def __init__(self, config: TokenizerConfig): diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index 25aa1ea4..f4ae9307 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -25,6 +25,7 @@ class DataType(str, enum.Enum): int16 = "int16" int8 = "int8" uint8 = "uint8" + uint16 = "uint16" @classmethod def _missing_(cls, dtype: str) -> "DataType": @@ -128,8 +129,9 @@ def _set_numpy_dtype_map(): DataType.int16: np.int16, DataType.int8: np.int8, DataType.uint8: np.uint8, + DataType.uint16: np.uint16, } - _TORCH_DTYPE_MAP_INV = {y: x for x, y in _NUMPY_DTYPE_MAP.items()} + _NUMPY_DTYPE_MAP_INV = {y: x for x, y in _NUMPY_DTYPE_MAP.items()} _TRITON_DTYPE_MAP: dict[DataType, "tl.dtype"] = {} diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index 7b338953..e9df18ed 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ -15,13 +15,15 @@ def fast_llm(args=None): # (Pre-)configure logging configure_logging() parser = argparse.ArgumentParser(add_help=False) - parser.add_argument("subcommand", choices=["train", "convert"]) + parser.add_argument("subcommand", choices=["train", "convert", "prepare"]) parsed, unparsed = parser.parse_known_args(args) try: if parsed.subcommand == "train": from fast_llm.tools.train import CliTrainingConfig as Runnable elif parsed.subcommand == "convert": from fast_llm.tools.convert import ConversionConfig as Runnable + elif parsed.subcommand == "prepare": + from fast_llm.tools.prepare_dataset import PrepareDatasetConfig as Runnable else: raise RuntimeError("Unknown subcommand") Runnable.parse_and_run(unparsed) diff --git a/fast_llm/tools/prepare_dataset.py b/fast_llm/tools/prepare_dataset.py new file mode 100644 index 00000000..aafe2690 --- /dev/null +++ b/fast_llm/tools/prepare_dataset.py @@ -0,0 +1,24 @@ +import argparse + +from fast_llm.data.auto import dataset_preparator_registry +from fast_llm.engine.config_utils.runnable import RunnableConfig + + +class PrepareDatasetConfig(RunnableConfig): + @classmethod + def _get_parser(cls): + parser = super()._get_parser() + parser.add_argument( + "model_type", + choices=dataset_preparator_registry.keys(), + help="The Fast-LLM dataset preparator to use. 
Must be defined in the dataset preparator registry in `fast_llm.data.auto`.", + ) + return parser + + @classmethod + def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): + return dataset_preparator_registry[parsed.model_type]._from_parsed_args(parsed, unparsed) + + +if __name__ == "__main__": + PrepareDatasetConfig.parse_and_run() diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 937aacd8..5ae1c5d0 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -175,8 +175,12 @@ def not_custom(fn, *args, **kwargs): ), f"Assertion failed: not fn({', '.join(itertools.chain((str(x) for x in args),(f'{str(k)}={str(v)}' for k,v in kwargs.items())))})" -class Registry: - def __init__(self, name, data: dict): +_KeyType = typing.TypeVar("_KeyType") +_ValueType = typing.TypeVar("_ValueType") + + +class Registry(typing.Generic[_KeyType, _ValueType]): + def __init__(self, name: str, data: dict[_KeyType, _ValueType]): self._name = name self._data = data.copy() diff --git a/setup.cfg b/setup.cfg index a353151c..51f87ac5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,11 +32,14 @@ OPTIONAL = # Huggingface tools transformers>=4.44.2 hf-transfer>=0.1.8 + datasets>=3.1.0 # Weights and biases wandb>=0.17.7 # Hydra hydra-core>=1.3.2 omegaconf>=2.3.0 + # Miscellaneous + tqdm>=4.66.3 # Required for testing DEV = diff --git a/tests/test_config.py b/tests/test_config.py index a9f2aeaf..c382aedb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,7 +6,7 @@ from fast_llm.models.auto import trainer_registry -def test_validate_without_import(): +def run_without_import(cmd: str): # Make sure validation imports only the bare minimum. # Run the test in a separate process since lots of things are already imported in this one. repo_path = pathlib.Path(__file__).parents[1].resolve() @@ -22,7 +22,7 @@ def test_validate_without_import(): # We still want to enable imports from within Fast-llm f"sys.path.append('{repo_path}')", "from fast_llm.tools.cli import fast_llm as main", - "main(['train', 'gpt', '-v'])", + cmd, ] ), ] @@ -32,6 +32,16 @@ def test_validate_without_import(): raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") +def test_validate_train_gpt_without_import(): + run_without_import("main(['train', 'gpt', '-v'])") + + +def test_validate_prepare_gpt_memmap_without_import(): + run_without_import( + "main(['prepare', 'gpt_memmap', '-v', 'dataset.path=test', 'output_path=test', 'tokenizer.path=test'])" + ) + + def test_validate_example_config(): fast_llm_config_dict = yaml.safe_load( (pathlib.Path(__file__).parents[1] / "examples" / "mistral-4-node-benchmark.yaml").read_text() ) diff --git a/tests/test_memmap_dataset.py b/tests/test_memmap_dataset.py new file mode 100644 index 00000000..413c2648 --- /dev/null +++ b/tests/test_memmap_dataset.py @@ -0,0 +1,21 @@ +import pathlib +import tempfile + +import numpy as np +import pytest + +from fast_llm.data.gpt.memmap import GPTMemmapDataset +from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES + + +@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) +def test_gpt_memmap_dataset(dtype): + documents = [np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype) for _ in range(100)] + with tempfile.TemporaryDirectory() as temp_dir: + prefix = pathlib.Path(temp_dir) + GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) + dataset = GPTMemmapDataset(name="foo", prefix=prefix) + for i, document in enumerate(documents): + assert np.array_equal( + dataset.get(i), document, 
equal_nan=True + ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."
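
As a usage sketch (not part of the diff above), the new `prepare` subcommand can be driven the same way `tests/test_config.py` drives it: the preparator name followed by dotted `key=value` overrides for `GPTMemmapDatasetPreparatorConfig` fields. The dataset name, paths, and the extra overrides (`dataset.field`, `tokens_per_shard`) below are placeholder assumptions, not values from the diff:

from fast_llm.tools.cli import fast_llm

# Prepare a GPT memmap dataset from a (hypothetical) Hugging Face dataset.
# The tests append "-v" to the argument list to only validate the configuration.
fast_llm(
    [
        "prepare",
        "gpt_memmap",
        "output_path=/tmp/my_memmap_dataset",  # where the .bin/.idx shards and fast_llm_dataset.json are written
        "dataset.path=my_org/my_text_dataset",  # placeholder Hugging Face dataset name or local path
        "dataset.field=text",  # text column to tokenize (the default)
        "tokenizer.path=/tmp/my_tokenizer",  # placeholder tokenizer path
        "tokens_per_shard=1000000000",  # ~1B tokens per shard (the default)
    ]
)

For multi-rank runs, `DatasetPreparatorDistributedConfig` reads `WORLD_SIZE` and `RANK` from the environment (e.g. as set by torchrun or an equivalent launcher), so the same invocation can be launched once per rank.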
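
Similarly, a minimal round-trip sketch of the new `GPTMemmapDataset.write_dataset` API, mirroring `tests/test_memmap_dataset.py`; the prefix name, document contents, and dtype are arbitrary examples:

import pathlib
import tempfile

import numpy as np

from fast_llm.data.gpt.memmap import GPTMemmapDataset

# Two toy documents of token ids, all using the same dtype.
documents = [np.array([5, 8, 13, 2], dtype=np.int32), np.array([4, 4, 7], dtype=np.int32)]

with tempfile.TemporaryDirectory() as temp_dir:
    # The prefix determines the output file names, e.g. shard_0_0.bin and shard_0_0.idx.
    prefix = pathlib.Path(temp_dir) / "shard_0_0"
    GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
    # Read the documents back by index.
    dataset = GPTMemmapDataset(name="example", prefix=prefix)
    assert np.array_equal(dataset.get(1), documents[1])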