From 411d92d44dc5d02354502a5b74f969f7e6c06484 Mon Sep 17 00:00:00 2001
From: Saaketh
Date: Fri, 15 Sep 2023 15:57:18 -0700
Subject: [PATCH] updated StreamingTextDataset and StreamingFinetuningDataset
 with new streaming args, bumped streaming version

---
 llmfoundry/data/finetuning/dataloader.py |   8 ++
 llmfoundry/data/finetuning/tasks.py      | 104 ++++++++++++++++-------
 llmfoundry/data/text_data.py             |  13 ++-
 setup.py                                 |   2 +-
 4 files changed, 94 insertions(+), 33 deletions(-)

diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 6b1562c37f..661b1e808d 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -125,11 +125,19 @@ def build_finetuning_dataloader(cfg: DictConfig,
             download_timeout=cfg.dataset.get('download_timeout', 60),
             validate_hash=cfg.dataset.get('validate_hash', None),
             keep_zip=cfg.dataset.get('keep_zip', False),
+            epoch_size=cfg.dataset.get('epoch_size', None),
             predownload=cfg.dataset.get('predownload', None),
+            cache_limit=cfg.dataset.get('cache_limit', None),
+            partition_algo=cfg.dataset.get('partition_algo', 'orig'),
             num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
             batch_size=device_batch_size,
             shuffle=cfg.dataset.get('shuffle', False),
+            shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1b'),
             shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
+            shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18),
+            sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
+            sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
+            batching_method=cfg.dataset.get('batching_method', 'random'),
         )
 
         collate_fn, dataloader_batch_size = _build_collate_fn(
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index c184dc9848..0a2b386048 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -71,44 +71,76 @@ class StreamingFinetuningDataset(StreamingDataset):
     """Finetuning dataset with flexible tokenization using StreamingDataset.
 
     Args:
-        local (str): Local dataset directory where shards are cached by split.
         tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
             tokenize samples.
-        remote (str, optional): Download shards from this remote path or directory. If None, this
-            rank and worker's partition of the dataset must all exist locally. Defaults to ``None``.
-        split (str, optional): Which dataset split to use, if any. Defaults to ``None``.
-        shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``.
-        predownload (int, optional): Target number of samples ahead to download the shards of while
-            iterating. Defaults to ``100_000``.
-        keep_zip (bool, optional): Whether to keep or delete the compressed file when
-            decompressing downloaded shards. If set to None, keep if remote is local. Defaults to
-            ``None``.
+        local (str): Local dataset directory where shards are cached by split.
+        remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
+            its data must exist locally. StreamingDataset uses either ``streams`` or
+            ``remote``/``local``. Defaults to ``None``.
+        split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
+            the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
         download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
         download_timeout (float): Number of seconds to wait for a shard to download before raising
             an exception. Defaults to ``60``.
         validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
             shards. Defaults to ``None``.
-        shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
-        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with resumption.
-            If ``None``, defaults to the number of nodes of the initial run. Defaults to 128.
+        keep_zip (bool): Whether to keep or delete the compressed form when decompressing
+            downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
+            ``False``.
+        epoch_size (int, optional): Number of samples to draw per epoch balanced across all
+            streams. If ``None``, takes its value from the total number of underlying samples.
+            Provide this field if you are weighting streams relatively to target a larger or
+            smaller epoch size. Defaults to ``None``.
+        predownload (int, optional): Target number of samples ahead to download the shards of while
+            iterating. Defaults to ``100_000``.
+        cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
+            shard cache. Before downloading a shard, the least recently used resident shard(s) may
+            be evicted (deleted from the local cache) in order to stay under the limit. Set to
+            ``None`` to disable shard eviction. Supports integer bytes as well as string
+            human-readable bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to ``None``.
+        partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
+        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
+            resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
+            initial run.
         batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset
             is partitioned over the workers. Defaults to ``None``.
+        shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
+            ``False``.
+        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
+        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
+        shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
+        sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
+            Defaults to ``balanced``.
+        sampling_granularity (int): When picking samples for a stream's final partial repeat,
+            how many samples to pick from the same shard at a time (``1`` for evenly balanced
+            across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc.).
+            Defaults to ``1``.
+        batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
+            ``per_stream``. Defaults to ``random``.
""" def __init__(self, - local: str, tokenizer: PreTrainedTokenizerBase, + local: str, remote: Optional[str] = None, split: Optional[str] = None, - shuffle: bool = False, - predownload: Optional[int] = 100_000, - keep_zip: bool = False, download_retry: int = 2, download_timeout: float = 60, validate_hash: Optional[str] = None, - shuffle_seed: int = 9176, - num_canonical_nodes: Optional[int] = 128, + keep_zip: bool = False, + epoch_size: Optional[int] = None, + predownload: Optional[int] = None, + cache_limit: Optional[Union[int, str]] = None, + partition_algo: str = 'orig', + num_canonical_nodes: Optional[int] = None, batch_size: Optional[int] = None, + shuffle: bool = False, + shuffle_algo: str = 'py1b', + shuffle_seed: int = 9176, + shuffle_block_size: int = 1 << 18, + sampling_method: str = 'balanced', + sampling_granularity: int = 1, + batching_method: str = 'random', **kwargs: Any): if len(kwargs) > 0: @@ -125,18 +157,28 @@ def __init__(self, ) # Build Dataset - super().__init__(local=local, - remote=remote, - split=split, - shuffle=shuffle, - predownload=predownload, - keep_zip=keep_zip, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - shuffle_seed=shuffle_seed, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size) + super().__init__( + local=local, + remote=remote, + split=split, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip, + epoch_size=epoch_size, + predownload=predownload, + cache_limit=cache_limit, + partition_algo=partition_algo, + num_canonical_nodes=num_canonical_nodes, + batch_size=batch_size, + shuffle=shuffle, + shuffle_algo=shuffle_algo, + shuffle_seed=shuffle_seed, + shuffle_block_size=shuffle_block_size, + sampling_method=sampling_method, + sampling_granularity=sampling_granularity, + batching_method=batching_method, + ) self.tokenizer = tokenizer diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 31626b237f..afdd243adf 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -66,7 +66,14 @@ class StreamingTextDataset(StreamingDataset): shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``. shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. - sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. Defaults to ``balanced``. + sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. + Defaults to ``balanced``. + sampling_granularity (int): When picking samples for a stream's final partial repeat, + how many samples to pick from the same shard at a time (``1`` for evenly balanced + across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc). + Defaults to ``1``. + batching_method (str): Which batching method to use, either ``random``, ``stratified``, or + ``per_stream``. Defaults to ``random``. 
""" def __init__(self, @@ -91,6 +98,8 @@ def __init__(self, shuffle_seed: int = 9176, shuffle_block_size: int = 1 << 18, sampling_method: str = 'balanced', + sampling_granularity: int = 1, + batching_method: str = 'random', **kwargs: Any): group_method = kwargs.pop('group_method', None) @@ -138,6 +147,8 @@ def __init__(self, shuffle_seed=shuffle_seed, shuffle_block_size=shuffle_block_size, sampling_method=sampling_method, + sampling_granularity=sampling_granularity, + batching_method=batching_method, ) self.tokenizer = tokenizer self.max_seq_len = max_seq_len diff --git a/setup.py b/setup.py index b07b8afe08..1a93bd05f7 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ 'mosaicml[libcloud,wandb,mlflow]>=0.16.1,<0.17', 'accelerate>=0.20,<0.21', # for HF inference `device_map` 'transformers>=4.33,<4.34', - 'mosaicml-streaming>=0.5.1,<0.6', + 'mosaicml-streaming>=0.6,<0.7', 'torch>=1.13.1,<2.1.1', 'datasets>=2.14.5,<2.15', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data