From 4cdcda07e0a83083f55e055a0f99886310c8f209 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Tue, 10 Sep 2024 09:04:08 -0700 Subject: [PATCH 01/16] refactor hf download --- llmfoundry/data/finetuning/tasks.py | 115 +++++++++++++++++----------- 1 file changed, 71 insertions(+), 44 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 915267786f..911d995f31 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -702,7 +702,73 @@ def state_dict(self, num_samples: int, num_samples=num_samples, from_beginning=from_beginning, ) + +def download_hf_dataset_if_needed( + dataset_name: str, + hf_kwargs: Optional[dict[str, Any]] = None +) -> str: + """ + Download a HuggingFace dataset locally if it does not already exist. + + Args: + dataset_name (str): The name of the HuggingFace dataset to use. Can be a remote http(s) + directory or object store bucket containing the file {split}.jsonl. + safe_load (bool): Whether to enforce safe loading of the dataset. + hf_kwargs (dict, optional): Additional kwargs to pass to `datasets.load_dataset`. + + Returns: + str: The local path to the dataset. + """ + if hf_kwargs is None: + hf_kwargs = {} + + if not os.path.isdir(dataset_name): + local_dataset_dir = os.path.join( + DOWNLOADED_FT_DATASETS_DIRPATH, + dataset_name, + ) + + if _is_empty_or_nonexistent(dirpath=local_dataset_dir): + # Safely load the dataset from HF Hub with restricted file types. + hf_hub.snapshot_download( + dataset_name, + repo_type='dataset', + allow_patterns=[ + '*' + ext for ext in SUPPORTED_EXTENSIONS + ], + token=hf_kwargs.get('token', None), + revision=hf_kwargs.get('revision', None), + local_dir_use_symlinks=False, + local_dir=local_dataset_dir, + ) + if _is_empty_or_nonexistent(dirpath=dataset_name): + log.error("Failed to safely load the dataset from HF Hub.") + raise InvalidFileExtensionError( + dataset_name, + SUPPORTED_EXTENSIONS, + ) + # Set dataset_name to the downloaded location. + dataset_name = local_dataset_dir + + # Ensure dataset_name is a local directory path (using abspath to avoid confusion). + dataset_name = os.path.abspath(dataset_name) + + # Check that the directory contains only allowed file types. + dataset_files = [ + f for _, _, files in os.walk(dataset_name) for f in files + ] + if not all( + Path(f).suffix in SUPPORTED_EXTENSIONS + + HUGGINGFACE_FOLDER_EXTENSIONS or f == '.gitignore' + for f in dataset_files + ): + log.error(f"Invalid file extension found in dataset during safe load.") + raise InvalidFileExtensionError( + dataset_name, + SUPPORTED_EXTENSIONS, + ) + return dataset_name class DatasetConstructor: @@ -901,50 +967,11 @@ def build_from_hf( filtered_dataset = None try: if safe_load: - if not os.path.isdir(dataset_name): - # dataset_name is not a local dir path, download if needed. - local_dataset_dir = os.path.join( - DOWNLOADED_FT_DATASETS_DIRPATH, - dataset_name, - ) - - if _is_empty_or_nonexistent(dirpath=local_dataset_dir): - # Safely load a dataset from HF Hub with restricted file types. - hf_hub.snapshot_download( - dataset_name, - repo_type='dataset', - allow_patterns=[ - '*' + ext for ext in SUPPORTED_EXTENSIONS - ], - token=hf_kwargs.get('token', None), - revision=hf_kwargs.get('revision', None), - local_dir_use_symlinks=False, - local_dir=local_dataset_dir, - ) - if _is_empty_or_nonexistent(dirpath=local_dataset_dir): - raise InvalidFileExtensionError( - dataset_name, - SUPPORTED_EXTENSIONS, - ) - # Set dataset_name to the downloaded location. 
- dataset_name = local_dataset_dir - - # dataset_name is a local dir path. Use the abspath to prevent confusion. - dataset_name = os.path.abspath(dataset_name) - - # Ensure that the local dir contains only allowed file types. - dataset_files = [ - f for _, _, files in os.walk(dataset_name) for f in files - ] - if not all( - Path(f).suffix in SUPPORTED_EXTENSIONS + - HUGGINGFACE_FOLDER_EXTENSIONS or f == '.gitignore' - for f in dataset_files - ): - raise InvalidFileExtensionError( - dataset_name, - SUPPORTED_EXTENSIONS, - ) + dataset_name = download_hf_dataset_if_needed( + dataset_name, + safe_load, + hf_kwargs, + ) dataset = hf_datasets.load_dataset( dataset_name, From a1385b4814930af0876af4c96873e703eb7f3c05 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Thu, 12 Sep 2024 13:34:59 -0700 Subject: [PATCH 02/16] split_eval_set skeleton --- llmfoundry/command_utils/__init__.py | 36 +++++++------ .../command_utils/data_prep/split_eval_set.py | 37 +++++++++++++ scripts/data_prep/split_eval_set.py | 54 +++++++++++++++++++ 3 files changed, 110 insertions(+), 17 deletions(-) create mode 100644 llmfoundry/command_utils/data_prep/split_eval_set.py create mode 100644 scripts/data_prep/split_eval_set.py diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py index 0226c4f408..5407b723cc 100644 --- a/llmfoundry/command_utils/__init__.py +++ b/llmfoundry/command_utils/__init__.py @@ -20,6 +20,7 @@ convert_text_to_mds, convert_text_to_mds_from_args, ) +from llmfoundry.command_utils.data_prep.split_eval_set import split_eval_set_from_args from llmfoundry.command_utils.eval import ( eval_from_yaml, evaluate, @@ -33,21 +34,22 @@ ) __all__ = [ - 'train', - 'train_from_yaml', - 'TrainConfig', - 'TRAIN_CONFIG_KEYS', - 'validate_config', - 'evaluate', - 'eval_from_yaml', - 'convert_dataset_hf', - 'convert_dataset_hf_from_args', - 'convert_dataset_json', - 'convert_dataset_json_from_args', - 'convert_finetuning_dataset_from_args', - 'convert_finetuning_dataset', - 'convert_text_to_mds', - 'convert_text_to_mds_from_args', - 'convert_delta_to_json_from_args', - 'fetch_DT', + "train", + "train_from_yaml", + "TrainConfig", + "TRAIN_CONFIG_KEYS", + "validate_config", + "evaluate", + "eval_from_yaml", + "convert_dataset_hf", + "convert_dataset_hf_from_args", + "convert_dataset_json", + "convert_dataset_json_from_args", + "convert_finetuning_dataset_from_args", + "convert_finetuning_dataset", + "convert_text_to_mds", + "convert_text_to_mds_from_args", + "convert_delta_to_json_from_args", + "fetch_DT", + "split_eval_set_from_args", ] diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py new file mode 100644 index 0000000000..01205cba15 --- /dev/null +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -0,0 +1,37 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +import json +from enum import Enum + +import datasets +from llmfoundry.data.finetuning.tasks import download_hf_dataset_if_needed + + +class SupportedDataFormats(Enum): + REMOTE_JSONL = "jsonl" # UC JSONL + DELTA_JSONL = "delta_jsonl" # Delta table preprocessed to JSONL + HF = "huggingface" + + +def validate_data_path(data_path: str) -> None: + """ + Validates the data path and returns the format of the data. 
+ + Args: + data_path (str): Path to the training dataset + """ + + +def split_eval_set_from_args() -> None: + """ + Args: + data_path_folder (str): Path to the training dataset folder + data_path_split (str): Data split + output_path (str): Directory to save the split dataset + eval_split_ratio (float): Ratio of the dataset to use for evaluation. The remainder will be used for training + max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used + seed (int): Random seed for splitting the dataset + """ + pass diff --git a/scripts/data_prep/split_eval_set.py b/scripts/data_prep/split_eval_set.py new file mode 100644 index 0000000000..ee8bfee453 --- /dev/null +++ b/scripts/data_prep/split_eval_set.py @@ -0,0 +1,54 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from argparse import ArgumentParser + +from llmfoundry.command_utils import split_eval_set_from_args + + +if __name__ == "__main__": + parser = ArgumentParser( + description="Split training dataset into train and eval sets", + ) + parser.add_argument( + "--data_path_folder", required=True, type=str, help="Path to the training dataset folder" + ) + parser.add_argument( + "--data_path_split", required=True, type=str, help="Path to the training dataset split" + ) + parser.add_argument( + "--output_path", + required=True, + type=str, + help="Path to save the split dataset", + ) + parser.add_argument( + "--eval_split_ratio", + required=False, + type=float, + default=0.1, + help="Ratio of the dataset to use for evaluation. The remainder will be used for training", + ) + parser.add_argument( + "--max_eval_samples", + required=False, + type=int, + default=None, + help="Maximum number of samples to include in the eval set", + ) + parser.add_argument( + "--seed", + required=False, + type=int, + default=42, + help="Random seed for splitting the dataset", + ) + args = parser.parse_args() + split_eval_set_from_args( + data_path_folder=args.data_path_folder, + data_path_split=args.data_path_split, + output_path=args.output_path, + eval_split_ratio=args.eval_split_ratio, + max_eval_samples=args.max_eval_samples, + seed=args.seed, + ) From 87d7a4c9142a05d9e29abbba8e9c1889e09eb447 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Sun, 15 Sep 2024 16:22:57 -0700 Subject: [PATCH 03/16] splitting script --- .../command_utils/data_prep/split_eval_set.py | 162 ++++++++++++++++-- llmfoundry/data/finetuning/tasks.py | 6 +- 2 files changed, 152 insertions(+), 16 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index 01205cba15..f6afc8722d 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -1,31 +1,167 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import logging import os +import re import json -from enum import Enum +import contextlib +import datasets as hf_datasets +import numpy as np +from typing import Optional -import datasets -from llmfoundry.data.finetuning.tasks import download_hf_dataset_if_needed +from composer.utils import get_file +from llmfoundry.data.finetuning.tasks import maybe_safe_download_hf_data -class SupportedDataFormats(Enum): - REMOTE_JSONL = "jsonl" # UC JSONL - DELTA_JSONL = "delta_jsonl" # Delta table preprocessed to JSONL - HF = "huggingface" +DELTA_JSONL_REGEX = re.compile(r"^tmp-t$") 
+REMOTE_OBJECT_STORE_FILE_REGEX = re.compile( + r"^((s3|oci|gs):\/\/|dbfs:\/Volumes\/)[/a-zA-Z0-9 ()_\-.]+$" +) +HF_REGEX = re.compile(r"^[/a-zA-Z0-9 ()_\-.]+$") +TEMP_DIR = "tmp-split" -def validate_data_path(data_path: str) -> None: +log = logging.getLogger(__name__) + +import sys + +log.setLevel(logging.DEBUG) +log.addHandler(logging.StreamHandler(sys.stdout)) + + +def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> str: """ - Validates the data path and returns the format of the data. + Prepares dataset as a local JSONL file. Downloads from remote object store or HF if necessary. + + This function is intended to be invoked by DBX Finetuning. + Thus, it assumes the provided data is in one of three formats: + 1. A Delta table converted to JSONL at 'tmp-t/{data_path_split}-00000-of-00001.jsonl` + using the 'llmfoundry.scripts.convert_delta_to_json.py' script. + 2. A JSONL stored as a remote object store file (e.g. S3, OCI, GCS) + 3. A Hugging Face dataset Args: - data_path (str): Path to the training dataset + data_path_folder (str): Path to the training dataset folder + data_path_split (str): Data split + + Returns: + str: Path to the training dataset """ + os.makedirs(TEMP_DIR, exist_ok=True) + + if DELTA_JSONL_REGEX.match(data_path_folder): + data_path = os.path.join(data_path_folder, f"{data_path_split}-00000-of-00001.jsonl") + if not os.path.exists(data_path): + # TODO: error handling + raise FileNotFoundError(f"File {data_path} does not exist.") + + if REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): + log.info( + f"Downloading dataset from remote object store: {data_path_folder}{data_path_split}.jsonl" + ) + remote_path = f"{data_path_folder}/{data_path_split}.jsonl" + data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl") + try: + get_file(remote_path, data_path, overwrite=True) + except FileNotFoundError as e: + # TODO: error handling + raise e + + elif HF_REGEX.match(data_path_folder): + log.info( + f"Downloading dataset from Hugging Face: {data_path_folder} with split {data_path_split}" + ) + # TODO: maybe add support for HF kwargs + local_hf_path = maybe_safe_download_hf_data(data_path_folder) + # convert dataset split to JSONL + dataset = hf_datasets.load_dataset( + local_hf_path, + split=data_path_split, + ) + data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl") + with open(data_path, "w") as f: + for example in dataset: + f.write(json.dumps(example) + "\n") + + else: + # TODO: error handling + raise ValueError( + f"Unrecognized data_path_folder: {data_path_folder}. Must be a Delta table, remote object store file, or Hugging Face dataset." + ) + + if not os.path.exists(data_path): + # TODO: error handling + raise FileNotFoundError(f"File {data_path} does not exist.") + + return data_path + +@contextlib.contextmanager +def temp_seed(seed: int): + state = np.random.get_state() + np.random.seed(seed) + try: + yield + finally: + np.random.set_state(state) -def split_eval_set_from_args() -> None: + +def _split_examples( + data_path: str, + output_path: str, + eval_split_ratio: float, + max_eval_samples: Optional[int], + seed: Optional[int] = None, +) -> None: + """ + Splits the dataset into training and evaluation sets. + + Args: + data_path (str): Path to the training dataset (local jsonl file) + eval_split_ratio (float): Ratio of the dataset to use for evaluation. The remainder will be used for training + max_eval_samples (int): Maximum number of samples to include in the eval set. 
If None, all eval_split_ratio * train_dataset_size samples will be used + seed (int): Random seed for splitting the dataset """ + # first pass: count total number of lines and determine sample size + total_lines = 0 + with open(data_path, "r") as infile: + for _ in infile: + total_lines += 1 + sample_size = int(eval_split_ratio * total_lines) + if max_eval_samples is not None: + sample_size = min(sample_size, max_eval_samples) + + with temp_seed(seed) if seed is not None else contextlib.nullcontext(): + random_numbers = np.random.rand(total_lines) + sample_indices = set(np.argsort(random_numbers)[:sample_size]) + + # second pass: sample indices + with open(data_path, "r") as infile, open( + os.path.join(output_path, "train.jsonl"), "w" + ) as train_outfile, open(os.path.join(output_path, "eval.jsonl"), "w") as eval_outfile: + for idx, line in enumerate(infile): + if idx in sample_indices: + eval_outfile.write(line) + else: + train_outfile.write(line) + + log.info( + f"Split {data_path} into train set of size {total_lines - sample_size} and eval set of size {sample_size}." + ) + + +def split_eval_set_from_args( + data_path_folder: str, + data_path_split: str, + output_path: str, + eval_split_ratio: float, + max_eval_samples: Optional[int] = None, + seed: Optional[int] = None, +) -> None: + """ + A wrapper for split_eval_set that parses arguments + Args: data_path_folder (str): Path to the training dataset folder data_path_split (str): Data split @@ -34,4 +170,6 @@ def split_eval_set_from_args() -> None: max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used seed (int): Random seed for splitting the dataset """ - pass + os.makedirs(output_path, exist_ok=True) + data_path = maybe_download_data_as_json(data_path_folder, data_path_split) + _split_examples(data_path, output_path, eval_split_ratio, max_eval_samples, seed) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 911d995f31..dc89b31730 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -703,7 +703,7 @@ def state_dict(self, num_samples: int, from_beginning=from_beginning, ) -def download_hf_dataset_if_needed( +def maybe_safe_download_hf_data( dataset_name: str, hf_kwargs: Optional[dict[str, Any]] = None ) -> str: @@ -713,7 +713,6 @@ def download_hf_dataset_if_needed( Args: dataset_name (str): The name of the HuggingFace dataset to use. Can be a remote http(s) directory or object store bucket containing the file {split}.jsonl. - safe_load (bool): Whether to enforce safe loading of the dataset. hf_kwargs (dict, optional): Additional kwargs to pass to `datasets.load_dataset`. 
Returns: @@ -967,9 +966,8 @@ def build_from_hf( filtered_dataset = None try: if safe_load: - dataset_name = download_hf_dataset_if_needed( + dataset_name = maybe_download_hf_data( dataset_name, - safe_load, hf_kwargs, ) From b04651f3c817fdb774243fc741bd5838c5b76241 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Mon, 16 Sep 2024 00:58:53 -0700 Subject: [PATCH 04/16] error handling and testing --- llmfoundry/command_utils/__init__.py | 6 +- .../command_utils/data_prep/split_eval_set.py | 38 ++-- .../data_prep/test_split_eval_set.py | 163 ++++++++++++++++++ 3 files changed, 183 insertions(+), 24 deletions(-) create mode 100644 tests/a_scripts/data_prep/test_split_eval_set.py diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py index 5407b723cc..8757f3b1bc 100644 --- a/llmfoundry/command_utils/__init__.py +++ b/llmfoundry/command_utils/__init__.py @@ -20,7 +20,10 @@ convert_text_to_mds, convert_text_to_mds_from_args, ) -from llmfoundry.command_utils.data_prep.split_eval_set import split_eval_set_from_args +from llmfoundry.command_utils.data_prep.split_eval_set import ( + split_eval_set_from_args, + split_examples, +) from llmfoundry.command_utils.eval import ( eval_from_yaml, evaluate, @@ -52,4 +55,5 @@ "convert_delta_to_json_from_args", "fetch_DT", "split_eval_set_from_args", + "split_examples", ] diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index f6afc8722d..b4b150f81f 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -10,7 +10,7 @@ import numpy as np from typing import Optional -from composer.utils import get_file +import composer.utils as utils from llmfoundry.data.finetuning.tasks import maybe_safe_download_hf_data @@ -24,11 +24,6 @@ log = logging.getLogger(__name__) -import sys - -log.setLevel(logging.DEBUG) -log.addHandler(logging.StreamHandler(sys.stdout)) - def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> str: """ @@ -51,22 +46,16 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> os.makedirs(TEMP_DIR, exist_ok=True) if DELTA_JSONL_REGEX.match(data_path_folder): + log.info(f"Dataset is converted from Delta table. Using local file {data_path_folder}") data_path = os.path.join(data_path_folder, f"{data_path_split}-00000-of-00001.jsonl") - if not os.path.exists(data_path): - # TODO: error handling - raise FileNotFoundError(f"File {data_path} does not exist.") - if REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): + elif REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): log.info( f"Downloading dataset from remote object store: {data_path_folder}{data_path_split}.jsonl" ) remote_path = f"{data_path_folder}/{data_path_split}.jsonl" data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl") - try: - get_file(remote_path, data_path, overwrite=True) - except FileNotFoundError as e: - # TODO: error handling - raise e + utils.get_file(remote_path, data_path, overwrite=True) elif HF_REGEX.match(data_path_folder): log.info( @@ -85,20 +74,21 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> f.write(json.dumps(example) + "\n") else: - # TODO: error handling raise ValueError( - f"Unrecognized data_path_folder: {data_path_folder}. Must be a Delta table, remote object store file, or Hugging Face dataset." 
+ f"Encountered unknown data path format when splitting dataset: {data_path_folder} with split {data_path_split}" ) if not os.path.exists(data_path): - # TODO: error handling - raise FileNotFoundError(f"File {data_path} does not exist.") + raise FileNotFoundError( + f"Expected dataset file at {data_path} for splitting, but it does not exist." + ) return data_path @contextlib.contextmanager def temp_seed(seed: int): + log.info(f"Setting random seed to {seed}") state = np.random.get_state() np.random.seed(seed) try: @@ -107,11 +97,11 @@ def temp_seed(seed: int): np.random.set_state(state) -def _split_examples( +def split_examples( data_path: str, output_path: str, eval_split_ratio: float, - max_eval_samples: Optional[int], + max_eval_samples: Optional[int] = None, seed: Optional[int] = None, ) -> None: """ @@ -119,10 +109,13 @@ def _split_examples( Args: data_path (str): Path to the training dataset (local jsonl file) + output_path (str): Directory to save the split dataset eval_split_ratio (float): Ratio of the dataset to use for evaluation. The remainder will be used for training max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used seed (int): Random seed for splitting the dataset """ + os.makedirs(output_path, exist_ok=True) + # first pass: count total number of lines and determine sample size total_lines = 0 with open(data_path, "r") as infile: @@ -170,6 +163,5 @@ def split_eval_set_from_args( max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used seed (int): Random seed for splitting the dataset """ - os.makedirs(output_path, exist_ok=True) data_path = maybe_download_data_as_json(data_path_folder, data_path_split) - _split_examples(data_path, output_path, eval_split_ratio, max_eval_samples, seed) + split_examples(data_path, output_path, eval_split_ratio, max_eval_samples, seed) diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py new file mode 100644 index 0000000000..a1b80b91cd --- /dev/null +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -0,0 +1,163 @@ +import os +import json +import pytest +import hashlib +from unittest.mock import patch + +from llmfoundry.command_utils import split_eval_set_from_args, split_examples + +# Default values +OUTPUT_DIR = "tmp-split" +TMPT_DIR = "tmp-t" +DATA_PATH_SPLIT = "train" +EVAL_SPLIT_RATIO = 0.1 +DEFAULT_FILE = TMPT_DIR + "/train-00000-of-00001.jsonl" + + +def calculate_file_hash(filepath: str) -> str: + with open(filepath, "rb") as f: + file_hash = hashlib.sha256(f.read()).hexdigest() + return file_hash + + +def count_lines(filepath: str) -> int: + with open(filepath, "r") as f: + return sum(1 for _ in f) + + +@pytest.fixture(scope="module", autouse=True) +def setup_and_teardown_module(): + # Setup: create local testing file + os.makedirs(TMPT_DIR, exist_ok=True) + with open(DEFAULT_FILE, "w") as f: + for i in range(1000): + f.write(json.dumps({"prompt": "hello world " + str(i), "response": "hi you!"}) + "\n") + yield + + # Teardown: clean up output and tmp directories + os.system(f"rm -rf {OUTPUT_DIR}") + os.system(f"rm -rf {TMPT_DIR}") + + +def test_basic_split(): + """Test basic functionality on local file""" + output_path = os.path.join(OUTPUT_DIR, "basic-test") + split_eval_set_from_args(TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO) + assert 
os.path.isfile(os.path.join(output_path, "train.jsonl")) + assert os.path.isfile(os.path.join(output_path, "eval.jsonl")) + + +def test_basic_split_output_exists(): + """Test that split overwrites existing files in output directory""" + output_path = os.path.join(OUTPUT_DIR, "basic-test") + os.makedirs(output_path, exist_ok=True) + train_file = os.path.join(output_path, "train.jsonl") + eval_file = os.path.join(output_path, "eval.jsonl") + with open(train_file, "w") as f: + f.write("existing file train") + with open(eval_file, "w") as f: + f.write("existing file eval") + old_train_hash = calculate_file_hash(train_file) + old_eval_hash = calculate_file_hash(eval_file) + split_eval_set_from_args( + TMPT_DIR, + DATA_PATH_SPLIT, + output_path, + EVAL_SPLIT_RATIO, + ) + assert calculate_file_hash(train_file) != old_train_hash + assert calculate_file_hash(eval_file) != old_eval_hash + + +def test_max_eval_samples(): + """Test case where max_eval_samples < eval_split_ratio * total samples""" + output_path = os.path.join(OUTPUT_DIR, "max-eval-test") + max_eval_samples = 50 + split_eval_set_from_args( + TMPT_DIR, + DATA_PATH_SPLIT, + output_path, + EVAL_SPLIT_RATIO, + max_eval_samples, + ) + eval_lines = count_lines(os.path.join(output_path, "eval.jsonl")) + assert eval_lines == max_eval_samples + + +def test_eval_split_ratio(): + """Test case where max_eval_samples is not used""" + output_path = os.path.join(OUTPUT_DIR, "eval-split-test") + split_eval_set_from_args(TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO) + original_data_lines = count_lines(DEFAULT_FILE) + eval_lines = count_lines(os.path.join(output_path, "eval.jsonl")) + assert abs(eval_lines - EVAL_SPLIT_RATIO * original_data_lines) < 1 # allow for rounding errors + + +def test_seed_consistency(): + """Test if the same seed generates consistent splits""" + output_path_1 = os.path.join(OUTPUT_DIR, "seed-test-1") + output_path_2 = os.path.join(OUTPUT_DIR, "seed-test-2") + split_examples(DEFAULT_FILE, output_path_1, EVAL_SPLIT_RATIO, seed=12345) + split_examples(DEFAULT_FILE, output_path_2, EVAL_SPLIT_RATIO, seed=12345) + train_hash_1 = calculate_file_hash(os.path.join(output_path_1, "train.jsonl")) + train_hash_2 = calculate_file_hash(os.path.join(output_path_2, "train.jsonl")) + eval_hash_1 = calculate_file_hash(os.path.join(output_path_1, "eval.jsonl")) + eval_hash_2 = calculate_file_hash(os.path.join(output_path_2, "eval.jsonl")) + + assert train_hash_1 == train_hash_2 + assert eval_hash_1 == eval_hash_2 + + output_path_3 = os.path.join(OUTPUT_DIR, "seed-test-3") + split_examples(DEFAULT_FILE, output_path_3, EVAL_SPLIT_RATIO, seed=54321) + train_hash_3 = calculate_file_hash(os.path.join(output_path_3, "train.jsonl")) + eval_hash_3 = calculate_file_hash(os.path.join(output_path_3, "eval.jsonl")) + + assert train_hash_1 != train_hash_3 + assert eval_hash_1 != eval_hash_3 + + +def test_hf_data_split(): + """Test splitting a dataset from Hugging Face""" + output_path = os.path.join(OUTPUT_DIR, "hf-split-test") + split_eval_set_from_args( + "databricks/databricks-dolly-15k", "train", output_path, EVAL_SPLIT_RATIO + ) + assert os.path.isfile(os.path.join(output_path, "train.jsonl")) + assert os.path.isfile(os.path.join(output_path, "eval.jsonl")) + assert count_lines(os.path.join(output_path, "train.jsonl")) > 0 + assert count_lines(os.path.join(output_path, "eval.jsonl")) > 0 + + +def _mock_get_file(remote_path: str, data_path: str, overwrite: bool): + with open(data_path, "w") as f: + for i in range(1000): + 
f.write(json.dumps({"prompt": "hello world " + str(i), "response": "hi you!"}) + "\n") + + +def test_remote_store_data_split(): + """Test splitting a dataset from a remote store""" + output_path = os.path.join(OUTPUT_DIR, "remote-split-test") + with patch("composer.utils.get_file", side_effect=_mock_get_file) as mock_get_file: + split_eval_set_from_args( + "dbfs:/Volumes/test/test/test.jsonl", + "unique-split-name", + output_path, + EVAL_SPLIT_RATIO, + ) + mock_get_file.assert_called() + + assert os.path.isfile(os.path.join(output_path, "train.jsonl")) + assert os.path.isfile(os.path.join(output_path, "eval.jsonl")) + assert count_lines(os.path.join(output_path, "train.jsonl")) > 0 + assert count_lines(os.path.join(output_path, "eval.jsonl")) > 0 + + +def test_missing_delta_file_error(): + # expects file 'TMPT_DIR/missing-00000-of-00001.jsonl + with pytest.raises(FileNotFoundError): + split_eval_set_from_args(TMPT_DIR, "missing", OUTPUT_DIR, EVAL_SPLIT_RATIO) + + +def test_unknown_file_format_error(): + with pytest.raises(ValueError): + split_eval_set_from_args("s3:/path/to/file.jsonl", "train", OUTPUT_DIR, EVAL_SPLIT_RATIO) From cc42fe4825d71d884a63f5ac2b508ac9b9c2f0e8 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Mon, 16 Sep 2024 01:08:53 -0700 Subject: [PATCH 05/16] undo autoformat --- llmfoundry/command_utils/__init__.py | 38 ++++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py index 8757f3b1bc..4f74fe6ec9 100644 --- a/llmfoundry/command_utils/__init__.py +++ b/llmfoundry/command_utils/__init__.py @@ -37,23 +37,23 @@ ) __all__ = [ - "train", - "train_from_yaml", - "TrainConfig", - "TRAIN_CONFIG_KEYS", - "validate_config", - "evaluate", - "eval_from_yaml", - "convert_dataset_hf", - "convert_dataset_hf_from_args", - "convert_dataset_json", - "convert_dataset_json_from_args", - "convert_finetuning_dataset_from_args", - "convert_finetuning_dataset", - "convert_text_to_mds", - "convert_text_to_mds_from_args", - "convert_delta_to_json_from_args", - "fetch_DT", - "split_eval_set_from_args", - "split_examples", + 'train', + 'train_from_yaml', + 'TrainConfig', + 'TRAIN_CONFIG_KEYS', + 'validate_config', + 'evaluate', + 'eval_from_yaml', + 'convert_dataset_hf', + 'convert_dataset_hf_from_args', + 'convert_dataset_json', + 'convert_dataset_json_from_args', + 'convert_finetuning_dataset_from_args', + 'convert_finetuning_dataset', + 'convert_text_to_mds', + 'convert_text_to_mds_from_args', + 'convert_delta_to_json_from_args', + 'fetch_DT', + 'split_eval_set_from_args', + 'split_examples', ] From 41eeb496cd86d544f71f522cd4ae5f3092538e81 Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Fri, 25 Oct 2024 12:02:01 -0400 Subject: [PATCH 06/16] add regex tests Signed-off-by: Jimmy Xu --- .../command_utils/data_prep/split_eval_set.py | 28 ++++++++-- scripts/data_prep/split_eval_set.py | 5 +- .../data_prep/test_split_eval_set.py | 54 +++++++++++++++++++ 3 files changed, 80 insertions(+), 7 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index b4b150f81f..c376f20184 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -13,6 +13,7 @@ import composer.utils as utils from llmfoundry.data.finetuning.tasks import maybe_safe_download_hf_data +log = logging.getLogger(__name__) DELTA_JSONL_REGEX = re.compile(r"^tmp-t$") 
REMOTE_OBJECT_STORE_FILE_REGEX = re.compile( @@ -20,10 +21,25 @@ ) HF_REGEX = re.compile(r"^[/a-zA-Z0-9 ()_\-.]+$") -TEMP_DIR = "tmp-split" +def get_dataset_format(data_path_folder: str) -> str: + """ + Determine the format of the dataset from the provided data path -log = logging.getLogger(__name__) + Args: + data_path_folder (str): Path to the training dataset folder + Returns: + str: The format of the dataset + """ + if DELTA_JSONL_REGEX.match(data_path_folder): + return "delta" + if REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): + return "remote_object_store" + if HF_REGEX.match(data_path_folder): + return "hugging_face" + return "unknown" + +TEMP_DIR = "tmp-split" def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> str: """ @@ -45,11 +61,13 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> """ os.makedirs(TEMP_DIR, exist_ok=True) - if DELTA_JSONL_REGEX.match(data_path_folder): + dataset_format = get_dataset_format(data_path_folder) + + if dataset_format == "delta": log.info(f"Dataset is converted from Delta table. Using local file {data_path_folder}") data_path = os.path.join(data_path_folder, f"{data_path_split}-00000-of-00001.jsonl") - elif REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): + elif dataset_format == "remote_object_store": log.info( f"Downloading dataset from remote object store: {data_path_folder}{data_path_split}.jsonl" ) @@ -57,7 +75,7 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl") utils.get_file(remote_path, data_path, overwrite=True) - elif HF_REGEX.match(data_path_folder): + elif dataset_format == "hugging_face": log.info( f"Downloading dataset from Hugging Face: {data_path_folder} with split {data_path_split}" ) diff --git a/scripts/data_prep/split_eval_set.py b/scripts/data_prep/split_eval_set.py index ee8bfee453..8631fceb4f 100644 --- a/scripts/data_prep/split_eval_set.py +++ b/scripts/data_prep/split_eval_set.py @@ -18,8 +18,9 @@ ) parser.add_argument( "--output_path", - required=True, + required=False, type=str, + default="/tmp-split", help="Path to save the split dataset", ) parser.add_argument( @@ -33,7 +34,7 @@ "--max_eval_samples", required=False, type=int, - default=None, + default=100, help="Maximum number of samples to include in the eval set", ) parser.add_argument( diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py index a1b80b91cd..7596debfc1 100644 --- a/tests/a_scripts/data_prep/test_split_eval_set.py +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -5,6 +5,7 @@ from unittest.mock import patch from llmfoundry.command_utils import split_eval_set_from_args, split_examples +from llmfoundry.command_utils.data_prep.split_eval_set import DELTA_JSONL_REGEX, get_dataset_format, HF_REGEX, REMOTE_OBJECT_STORE_FILE_REGEX # Default values OUTPUT_DIR = "tmp-split" @@ -14,6 +15,59 @@ DEFAULT_FILE = TMPT_DIR + "/train-00000-of-00001.jsonl" +def test_delta_jsonl_regex(): + """Test the regex pattern matches tmp-t exactly""" + assert DELTA_JSONL_REGEX.match("tmp-t") + assert not DELTA_JSONL_REGEX.match("/tmp-t") + assert not DELTA_JSONL_REGEX.match("tmp-t-00000-of-00001.jsonl") + assert not DELTA_JSONL_REGEX.match("tmp-t-something") + assert not DELTA_JSONL_REGEX.match("tmp-t/") + assert not DELTA_JSONL_REGEX.match("tmp-t\\") + +def test_remote_object_store_file_regex(): + """Test the regex pattern for remote object store 
file paths""" + assert REMOTE_OBJECT_STORE_FILE_REGEX.match("s3://bucket-name/path/to/file") + assert REMOTE_OBJECT_STORE_FILE_REGEX.match("oci://bucket-name/path/to/file") + assert REMOTE_OBJECT_STORE_FILE_REGEX.match("gs://bucket-name/path/to/file") + assert REMOTE_OBJECT_STORE_FILE_REGEX.match("dbfs:/Volumes/path/to/file") + assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("https://bucket-name/path/to/file") + assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("/local/path/to/file") + assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("s3:/bucket-name/path/to/file") + assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("s3://bucket-name/path/to/file with spaces") + assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("s3://bucket-name/path/to/file?") + +def test_hf_regex(): + """Test the regex pattern for Hugging Face dataset paths""" + assert HF_REGEX.match("dataset-name") + assert HF_REGEX.match("dataset_name") + assert HF_REGEX.match("dataset-name_with-mixed.characters-123") + assert not HF_REGEX.match("dataset/name") + assert not HF_REGEX.match("dataset\\name") + assert not HF_REGEX.match("dataset:name") + +def test_get_dataset_format(): + """Test the get_dataset_format function""" + + # Test delta format + assert get_dataset_format("tmp-t") == "delta" + assert get_dataset_format("tmp-t/") == "unknown" + + # Test remote object store format + assert get_dataset_format("s3://bucket-name/path/to/file") == "remote_object_store" + assert get_dataset_format("oci://bucket-name/path/to/file") == "remote_object_store" + assert get_dataset_format("gs://bucket-name/path/to/file") == "remote_object_store" + assert get_dataset_format("dbfs:/Volumes/path/to/file") == "remote_object_store" + + # Test Hugging Face format + assert get_dataset_format("dataset-name") == "hugging_face" + assert get_dataset_format("dataset_name") == "hugging_face" + assert get_dataset_format("dataset-name_with-mixed.characters-123") == "hugging_face" + + # Test unknown format + assert get_dataset_format("/local/path/to/file") == "unknown" + assert get_dataset_format("s3:/bucket-name/path/to/file") == "unknown" + assert get_dataset_format("dataset:name") == "unknown" + def calculate_file_hash(filepath: str) -> str: with open(filepath, "rb") as f: file_hash = hashlib.sha256(f.read()).hexdigest() From 87d0279a7043916f47e47175082d3d6861feb7fd Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Mon, 28 Oct 2024 16:08:03 -0400 Subject: [PATCH 07/16] remove hf support Signed-off-by: Jimmy Xu --- .../command_utils/data_prep/split_eval_set.py | 27 ++--------------- .../data_prep/test_split_eval_set.py | 30 ++----------------- 2 files changed, 4 insertions(+), 53 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index c376f20184..59d3b43432 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -4,14 +4,11 @@ import logging import os import re -import json import contextlib -import datasets as hf_datasets import numpy as np from typing import Optional import composer.utils as utils -from llmfoundry.data.finetuning.tasks import maybe_safe_download_hf_data log = logging.getLogger(__name__) @@ -19,7 +16,6 @@ REMOTE_OBJECT_STORE_FILE_REGEX = re.compile( r"^((s3|oci|gs):\/\/|dbfs:\/Volumes\/)[/a-zA-Z0-9 ()_\-.]+$" ) -HF_REGEX = re.compile(r"^[/a-zA-Z0-9 ()_\-.]+$") def get_dataset_format(data_path_folder: str) -> str: """ @@ -35,22 +31,19 @@ def get_dataset_format(data_path_folder: str) -> str: return 
"delta" if REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): return "remote_object_store" - if HF_REGEX.match(data_path_folder): - return "hugging_face" return "unknown" TEMP_DIR = "tmp-split" def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> str: """ - Prepares dataset as a local JSONL file. Downloads from remote object store or HF if necessary. + Prepares dataset as a local JSONL file. Downloads from remote object store if necessary. This function is intended to be invoked by DBX Finetuning. - Thus, it assumes the provided data is in one of three formats: + Thus, it assumes the provided data is: 1. A Delta table converted to JSONL at 'tmp-t/{data_path_split}-00000-of-00001.jsonl` using the 'llmfoundry.scripts.convert_delta_to_json.py' script. 2. A JSONL stored as a remote object store file (e.g. S3, OCI, GCS) - 3. A Hugging Face dataset Args: data_path_folder (str): Path to the training dataset folder @@ -75,22 +68,6 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl") utils.get_file(remote_path, data_path, overwrite=True) - elif dataset_format == "hugging_face": - log.info( - f"Downloading dataset from Hugging Face: {data_path_folder} with split {data_path_split}" - ) - # TODO: maybe add support for HF kwargs - local_hf_path = maybe_safe_download_hf_data(data_path_folder) - # convert dataset split to JSONL - dataset = hf_datasets.load_dataset( - local_hf_path, - split=data_path_split, - ) - data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl") - with open(data_path, "w") as f: - for example in dataset: - f.write(json.dumps(example) + "\n") - else: raise ValueError( f"Encountered unknown data path format when splitting dataset: {data_path_folder} with split {data_path_split}" diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py index 7596debfc1..cb78729648 100644 --- a/tests/a_scripts/data_prep/test_split_eval_set.py +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -5,7 +5,7 @@ from unittest.mock import patch from llmfoundry.command_utils import split_eval_set_from_args, split_examples -from llmfoundry.command_utils.data_prep.split_eval_set import DELTA_JSONL_REGEX, get_dataset_format, HF_REGEX, REMOTE_OBJECT_STORE_FILE_REGEX +from llmfoundry.command_utils.data_prep.split_eval_set import DELTA_JSONL_REGEX, get_dataset_format, REMOTE_OBJECT_STORE_FILE_REGEX # Default values OUTPUT_DIR = "tmp-split" @@ -30,21 +30,12 @@ def test_remote_object_store_file_regex(): assert REMOTE_OBJECT_STORE_FILE_REGEX.match("oci://bucket-name/path/to/file") assert REMOTE_OBJECT_STORE_FILE_REGEX.match("gs://bucket-name/path/to/file") assert REMOTE_OBJECT_STORE_FILE_REGEX.match("dbfs:/Volumes/path/to/file") + assert REMOTE_OBJECT_STORE_FILE_REGEX.match("s3://bucket-name/path/to/file with spaces") assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("https://bucket-name/path/to/file") assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("/local/path/to/file") assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("s3:/bucket-name/path/to/file") - assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("s3://bucket-name/path/to/file with spaces") assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("s3://bucket-name/path/to/file?") -def test_hf_regex(): - """Test the regex pattern for Hugging Face dataset paths""" - assert HF_REGEX.match("dataset-name") - assert HF_REGEX.match("dataset_name") - assert 
HF_REGEX.match("dataset-name_with-mixed.characters-123") - assert not HF_REGEX.match("dataset/name") - assert not HF_REGEX.match("dataset\\name") - assert not HF_REGEX.match("dataset:name") - def test_get_dataset_format(): """Test the get_dataset_format function""" @@ -58,11 +49,6 @@ def test_get_dataset_format(): assert get_dataset_format("gs://bucket-name/path/to/file") == "remote_object_store" assert get_dataset_format("dbfs:/Volumes/path/to/file") == "remote_object_store" - # Test Hugging Face format - assert get_dataset_format("dataset-name") == "hugging_face" - assert get_dataset_format("dataset_name") == "hugging_face" - assert get_dataset_format("dataset-name_with-mixed.characters-123") == "hugging_face" - # Test unknown format assert get_dataset_format("/local/path/to/file") == "unknown" assert get_dataset_format("s3:/bucket-name/path/to/file") == "unknown" @@ -170,18 +156,6 @@ def test_seed_consistency(): assert eval_hash_1 != eval_hash_3 -def test_hf_data_split(): - """Test splitting a dataset from Hugging Face""" - output_path = os.path.join(OUTPUT_DIR, "hf-split-test") - split_eval_set_from_args( - "databricks/databricks-dolly-15k", "train", output_path, EVAL_SPLIT_RATIO - ) - assert os.path.isfile(os.path.join(output_path, "train.jsonl")) - assert os.path.isfile(os.path.join(output_path, "eval.jsonl")) - assert count_lines(os.path.join(output_path, "train.jsonl")) > 0 - assert count_lines(os.path.join(output_path, "eval.jsonl")) > 0 - - def _mock_get_file(remote_path: str, data_path: str, overwrite: bool): with open(data_path, "w") as f: for i in range(1000): From 81f3b8fdfcc578681a71ca294960bac8bfb52e8b Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Mon, 28 Oct 2024 17:03:55 -0400 Subject: [PATCH 08/16] fix dataloader test? Signed-off-by: Jimmy Xu --- llmfoundry/data/finetuning/tasks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 490b10a7c9..5e744cde73 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -731,11 +731,12 @@ def maybe_safe_download_hf_data( hf_kwargs = {} if not os.path.isdir(dataset_name): + # dataset_name is not a local dir path, download if needed. 
local_dataset_dir = os.path.join( tempfile.mkdtemp(), dataset_name, ) - + log.debug( f'Downloading dataset {dataset_name} to {local_dataset_dir}.', ) @@ -753,7 +754,7 @@ def maybe_safe_download_hf_data( local_dir_use_symlinks=False, local_dir=local_dataset_dir, ) - if _is_empty_or_nonexistent(dirpath=dataset_name): + if _is_empty_or_nonexistent(dirpath=local_dataset_dir): log.error("Failed to safely load the dataset from HF Hub.") raise InvalidFileExtensionError( dataset_name, From d14fde5a45f0e09f506b143ec45e1a3c62162895 Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Mon, 28 Oct 2024 17:39:27 -0400 Subject: [PATCH 09/16] lint Signed-off-by: Jimmy Xu --- .../command_utils/data_prep/split_eval_set.py | 77 +++--- llmfoundry/data/finetuning/tasks.py | 23 +- scripts/data_prep/split_eval_set.py | 34 +-- .../data_prep/test_split_eval_set.py | 225 +++++++++++------- 4 files changed, 213 insertions(+), 146 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index 59d3b43432..223b94241f 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -1,25 +1,25 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import contextlib import logging import os import re -import contextlib -import numpy as np from typing import Optional import composer.utils as utils +import numpy as np log = logging.getLogger(__name__) -DELTA_JSONL_REGEX = re.compile(r"^tmp-t$") +DELTA_JSONL_REGEX = re.compile(r'^tmp-t$') REMOTE_OBJECT_STORE_FILE_REGEX = re.compile( - r"^((s3|oci|gs):\/\/|dbfs:\/Volumes\/)[/a-zA-Z0-9 ()_\-.]+$" + r'^((s3|oci|gs):\/\/|dbfs:\/Volumes\/)[/a-zA-Z0-9 ()_\-.]+$', ) + def get_dataset_format(data_path_folder: str) -> str: - """ - Determine the format of the dataset from the provided data path + """Determine the format of the dataset from the provided data path. Args: data_path_folder (str): Path to the training dataset folder @@ -28,16 +28,20 @@ def get_dataset_format(data_path_folder: str) -> str: str: The format of the dataset """ if DELTA_JSONL_REGEX.match(data_path_folder): - return "delta" + return 'delta' if REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): - return "remote_object_store" - return "unknown" + return 'remote_object_store' + return 'unknown' -TEMP_DIR = "tmp-split" -def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> str: - """ - Prepares dataset as a local JSONL file. Downloads from remote object store if necessary. +TEMP_DIR = 'tmp-split' + + +def maybe_download_data_as_json( + data_path_folder: str, data_path_split: str +) -> str: + """Prepares dataset as a local JSONL file. Downloads from remote object + store if necessary. This function is intended to be invoked by DBX Finetuning. Thus, it assumes the provided data is: @@ -56,26 +60,30 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> dataset_format = get_dataset_format(data_path_folder) - if dataset_format == "delta": - log.info(f"Dataset is converted from Delta table. Using local file {data_path_folder}") - data_path = os.path.join(data_path_folder, f"{data_path_split}-00000-of-00001.jsonl") + if dataset_format == 'delta': + log.info( + f'Dataset is converted from Delta table. 
Using local file {data_path_folder}' + ) + data_path = os.path.join( + data_path_folder, f'{data_path_split}-00000-of-00001.jsonl' + ) - elif dataset_format == "remote_object_store": + elif dataset_format == 'remote_object_store': log.info( - f"Downloading dataset from remote object store: {data_path_folder}{data_path_split}.jsonl" + f'Downloading dataset from remote object store: {data_path_folder}{data_path_split}.jsonl', ) - remote_path = f"{data_path_folder}/{data_path_split}.jsonl" - data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl") + remote_path = f'{data_path_folder}/{data_path_split}.jsonl' + data_path = os.path.join(TEMP_DIR, f'{data_path_split}.jsonl') utils.get_file(remote_path, data_path, overwrite=True) else: raise ValueError( - f"Encountered unknown data path format when splitting dataset: {data_path_folder} with split {data_path_split}" + f'Encountered unknown data path format when splitting dataset: {data_path_folder} with split {data_path_split}', ) if not os.path.exists(data_path): raise FileNotFoundError( - f"Expected dataset file at {data_path} for splitting, but it does not exist." + f'Expected dataset file at {data_path} for splitting, but it does not exist.', ) return data_path @@ -83,7 +91,7 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> @contextlib.contextmanager def temp_seed(seed: int): - log.info(f"Setting random seed to {seed}") + log.info(f'Setting random seed to {seed}') state = np.random.get_state() np.random.seed(seed) try: @@ -99,8 +107,7 @@ def split_examples( max_eval_samples: Optional[int] = None, seed: Optional[int] = None, ) -> None: - """ - Splits the dataset into training and evaluation sets. + """Splits the dataset into training and evaluation sets. Args: data_path (str): Path to the training dataset (local jsonl file) @@ -113,7 +120,7 @@ def split_examples( # first pass: count total number of lines and determine sample size total_lines = 0 - with open(data_path, "r") as infile: + with open(data_path, 'r') as infile: for _ in infile: total_lines += 1 sample_size = int(eval_split_ratio * total_lines) @@ -125,9 +132,12 @@ def split_examples( sample_indices = set(np.argsort(random_numbers)[:sample_size]) # second pass: sample indices - with open(data_path, "r") as infile, open( - os.path.join(output_path, "train.jsonl"), "w" - ) as train_outfile, open(os.path.join(output_path, "eval.jsonl"), "w") as eval_outfile: + with open(data_path, 'r') as infile, open( + os.path.join(output_path, 'train.jsonl'), + 'w', + ) as train_outfile, open( + os.path.join(output_path, 'eval.jsonl'), 'w' + ) as eval_outfile: for idx, line in enumerate(infile): if idx in sample_indices: eval_outfile.write(line) @@ -135,7 +145,7 @@ def split_examples( train_outfile.write(line) log.info( - f"Split {data_path} into train set of size {total_lines - sample_size} and eval set of size {sample_size}." + f'Split {data_path} into train set of size {total_lines - sample_size} and eval set of size {sample_size}.', ) @@ -147,8 +157,7 @@ def split_eval_set_from_args( max_eval_samples: Optional[int] = None, seed: Optional[int] = None, ) -> None: - """ - A wrapper for split_eval_set that parses arguments + """A wrapper for split_eval_set that parses arguments. 
Args: data_path_folder (str): Path to the training dataset folder @@ -159,4 +168,6 @@ def split_eval_set_from_args( seed (int): Random seed for splitting the dataset """ data_path = maybe_download_data_as_json(data_path_folder, data_path_split) - split_examples(data_path, output_path, eval_split_ratio, max_eval_samples, seed) + split_examples( + data_path, output_path, eval_split_ratio, max_eval_samples, seed + ) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 5e744cde73..68bd1f5a23 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -711,16 +711,16 @@ def state_dict(self, num_samples: int, num_samples=num_samples, from_beginning=from_beginning, ) - + + def maybe_safe_download_hf_data( dataset_name: str, - hf_kwargs: Optional[dict[str, Any]] = None + hf_kwargs: Optional[dict[str, Any]] = None, ) -> str: - """ - Download a HuggingFace dataset locally if it does not already exist. + """Download a HuggingFace dataset locally if it does not already exist. Args: - dataset_name (str): The name of the HuggingFace dataset to use. Can be a remote http(s) + dataset_name (str): The name of the HuggingFace dataset to use. Can be a remote http(s) directory or object store bucket containing the file {split}.jsonl. hf_kwargs (dict, optional): Additional kwargs to pass to `datasets.load_dataset`. @@ -746,16 +746,14 @@ def maybe_safe_download_hf_data( hf_hub.snapshot_download( dataset_name, repo_type='dataset', - allow_patterns=[ - '*' + ext for ext in SUPPORTED_EXTENSIONS - ], + allow_patterns=['*' + ext for ext in SUPPORTED_EXTENSIONS], token=hf_kwargs.get('token', None), revision=hf_kwargs.get('revision', None), local_dir_use_symlinks=False, local_dir=local_dataset_dir, ) if _is_empty_or_nonexistent(dirpath=local_dataset_dir): - log.error("Failed to safely load the dataset from HF Hub.") + log.error('Failed to safely load the dataset from HF Hub.') raise InvalidFileExtensionError( dataset_name, SUPPORTED_EXTENSIONS, @@ -767,15 +765,13 @@ def maybe_safe_download_hf_data( dataset_name = os.path.abspath(dataset_name) # Check that the directory contains only allowed file types. 
- dataset_files = [ - f for _, _, files in os.walk(dataset_name) for f in files - ] + dataset_files = [f for _, _, files in os.walk(dataset_name) for f in files] if not all( Path(f).suffix in SUPPORTED_EXTENSIONS + HUGGINGFACE_FOLDER_EXTENSIONS or f == '.gitignore' for f in dataset_files ): - log.error(f"Invalid file extension found in dataset during safe load.") + log.error(f'Invalid file extension found in dataset during safe load.') raise InvalidFileExtensionError( dataset_name, SUPPORTED_EXTENSIONS, @@ -783,6 +779,7 @@ def maybe_safe_download_hf_data( return dataset_name + class DatasetConstructor: def __init__(self): diff --git a/scripts/data_prep/split_eval_set.py b/scripts/data_prep/split_eval_set.py index 8631fceb4f..42aa1c82f0 100644 --- a/scripts/data_prep/split_eval_set.py +++ b/scripts/data_prep/split_eval_set.py @@ -5,44 +5,50 @@ from llmfoundry.command_utils import split_eval_set_from_args - -if __name__ == "__main__": +if __name__ == '__main__': parser = ArgumentParser( - description="Split training dataset into train and eval sets", + description='Split training dataset into train and eval sets', ) parser.add_argument( - "--data_path_folder", required=True, type=str, help="Path to the training dataset folder" + '--data_path_folder', + required=True, + type=str, + help='Path to the training dataset folder', ) parser.add_argument( - "--data_path_split", required=True, type=str, help="Path to the training dataset split" + '--data_path_split', + required=True, + type=str, + help='Path to the training dataset split', ) parser.add_argument( - "--output_path", + '--output_path', required=False, type=str, - default="/tmp-split", - help="Path to save the split dataset", + default='/tmp-split', + help='Path to save the split dataset', ) parser.add_argument( - "--eval_split_ratio", + '--eval_split_ratio', required=False, type=float, default=0.1, - help="Ratio of the dataset to use for evaluation. The remainder will be used for training", + help= + 'Ratio of the dataset to use for evaluation. 
The remainder will be used for training', ) parser.add_argument( - "--max_eval_samples", + '--max_eval_samples', required=False, type=int, default=100, - help="Maximum number of samples to include in the eval set", + help='Maximum number of samples to include in the eval set', ) parser.add_argument( - "--seed", + '--seed', required=False, type=int, default=42, - help="Random seed for splitting the dataset", + help='Random seed for splitting the dataset', ) args = parser.parse_args() split_eval_set_from_args( diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py index cb78729648..a9e09d4ce3 100644 --- a/tests/a_scripts/data_prep/test_split_eval_set.py +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -1,102 +1,134 @@ -import os -import json -import pytest +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + import hashlib +import json +import os from unittest.mock import patch +import pytest + from llmfoundry.command_utils import split_eval_set_from_args, split_examples -from llmfoundry.command_utils.data_prep.split_eval_set import DELTA_JSONL_REGEX, get_dataset_format, REMOTE_OBJECT_STORE_FILE_REGEX +from llmfoundry.command_utils.data_prep.split_eval_set import ( + DELTA_JSONL_REGEX, REMOTE_OBJECT_STORE_FILE_REGEX, get_dataset_format,) # Default values -OUTPUT_DIR = "tmp-split" -TMPT_DIR = "tmp-t" -DATA_PATH_SPLIT = "train" +OUTPUT_DIR = 'tmp-split' +TMPT_DIR = 'tmp-t' +DATA_PATH_SPLIT = 'train' EVAL_SPLIT_RATIO = 0.1 -DEFAULT_FILE = TMPT_DIR + "/train-00000-of-00001.jsonl" +DEFAULT_FILE = TMPT_DIR + '/train-00000-of-00001.jsonl' def test_delta_jsonl_regex(): - """Test the regex pattern matches tmp-t exactly""" - assert DELTA_JSONL_REGEX.match("tmp-t") - assert not DELTA_JSONL_REGEX.match("/tmp-t") - assert not DELTA_JSONL_REGEX.match("tmp-t-00000-of-00001.jsonl") - assert not DELTA_JSONL_REGEX.match("tmp-t-something") - assert not DELTA_JSONL_REGEX.match("tmp-t/") - assert not DELTA_JSONL_REGEX.match("tmp-t\\") + """Test the regex pattern matches tmp-t exactly.""" + assert DELTA_JSONL_REGEX.match('tmp-t') + assert not DELTA_JSONL_REGEX.match('/tmp-t') + assert not DELTA_JSONL_REGEX.match('tmp-t-00000-of-00001.jsonl') + assert not DELTA_JSONL_REGEX.match('tmp-t-something') + assert not DELTA_JSONL_REGEX.match('tmp-t/') + assert not DELTA_JSONL_REGEX.match('tmp-t\\') + def test_remote_object_store_file_regex(): - """Test the regex pattern for remote object store file paths""" - assert REMOTE_OBJECT_STORE_FILE_REGEX.match("s3://bucket-name/path/to/file") - assert REMOTE_OBJECT_STORE_FILE_REGEX.match("oci://bucket-name/path/to/file") - assert REMOTE_OBJECT_STORE_FILE_REGEX.match("gs://bucket-name/path/to/file") - assert REMOTE_OBJECT_STORE_FILE_REGEX.match("dbfs:/Volumes/path/to/file") - assert REMOTE_OBJECT_STORE_FILE_REGEX.match("s3://bucket-name/path/to/file with spaces") - assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("https://bucket-name/path/to/file") - assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("/local/path/to/file") - assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("s3:/bucket-name/path/to/file") - assert not REMOTE_OBJECT_STORE_FILE_REGEX.match("s3://bucket-name/path/to/file?") + """Test the regex pattern for remote object store file paths.""" + assert REMOTE_OBJECT_STORE_FILE_REGEX.match('s3://bucket-name/path/to/file') + assert REMOTE_OBJECT_STORE_FILE_REGEX.match( + 'oci://bucket-name/path/to/file' + ) + assert 
REMOTE_OBJECT_STORE_FILE_REGEX.match('gs://bucket-name/path/to/file') + assert REMOTE_OBJECT_STORE_FILE_REGEX.match('dbfs:/Volumes/path/to/file') + assert REMOTE_OBJECT_STORE_FILE_REGEX.match( + 's3://bucket-name/path/to/file with spaces' + ) + assert not REMOTE_OBJECT_STORE_FILE_REGEX.match( + 'https://bucket-name/path/to/file' + ) + assert not REMOTE_OBJECT_STORE_FILE_REGEX.match('/local/path/to/file') + assert not REMOTE_OBJECT_STORE_FILE_REGEX.match( + 's3:/bucket-name/path/to/file' + ) + assert not REMOTE_OBJECT_STORE_FILE_REGEX.match( + 's3://bucket-name/path/to/file?' + ) -def test_get_dataset_format(): - """Test the get_dataset_format function""" +def test_get_dataset_format(): + """Test the get_dataset_format function.""" # Test delta format - assert get_dataset_format("tmp-t") == "delta" - assert get_dataset_format("tmp-t/") == "unknown" + assert get_dataset_format('tmp-t') == 'delta' + assert get_dataset_format('tmp-t/') == 'unknown' # Test remote object store format - assert get_dataset_format("s3://bucket-name/path/to/file") == "remote_object_store" - assert get_dataset_format("oci://bucket-name/path/to/file") == "remote_object_store" - assert get_dataset_format("gs://bucket-name/path/to/file") == "remote_object_store" - assert get_dataset_format("dbfs:/Volumes/path/to/file") == "remote_object_store" + assert get_dataset_format( + 's3://bucket-name/path/to/file' + ) == 'remote_object_store' + assert get_dataset_format( + 'oci://bucket-name/path/to/file' + ) == 'remote_object_store' + assert get_dataset_format( + 'gs://bucket-name/path/to/file' + ) == 'remote_object_store' + assert get_dataset_format( + 'dbfs:/Volumes/path/to/file' + ) == 'remote_object_store' # Test unknown format - assert get_dataset_format("/local/path/to/file") == "unknown" - assert get_dataset_format("s3:/bucket-name/path/to/file") == "unknown" - assert get_dataset_format("dataset:name") == "unknown" + assert get_dataset_format('/local/path/to/file') == 'unknown' + assert get_dataset_format('s3:/bucket-name/path/to/file') == 'unknown' + assert get_dataset_format('dataset:name') == 'unknown' + def calculate_file_hash(filepath: str) -> str: - with open(filepath, "rb") as f: + with open(filepath, 'rb') as f: file_hash = hashlib.sha256(f.read()).hexdigest() return file_hash def count_lines(filepath: str) -> int: - with open(filepath, "r") as f: + with open(filepath, 'r') as f: return sum(1 for _ in f) -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(scope='module', autouse=True) def setup_and_teardown_module(): # Setup: create local testing file os.makedirs(TMPT_DIR, exist_ok=True) - with open(DEFAULT_FILE, "w") as f: + with open(DEFAULT_FILE, 'w') as f: for i in range(1000): - f.write(json.dumps({"prompt": "hello world " + str(i), "response": "hi you!"}) + "\n") + f.write( + json.dumps({ + 'prompt': 'hello world ' + str(i), + 'response': 'hi you!' 
+ }) + '\n' + ) yield # Teardown: clean up output and tmp directories - os.system(f"rm -rf {OUTPUT_DIR}") - os.system(f"rm -rf {TMPT_DIR}") + os.system(f'rm -rf {OUTPUT_DIR}') + os.system(f'rm -rf {TMPT_DIR}') def test_basic_split(): - """Test basic functionality on local file""" - output_path = os.path.join(OUTPUT_DIR, "basic-test") - split_eval_set_from_args(TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO) - assert os.path.isfile(os.path.join(output_path, "train.jsonl")) - assert os.path.isfile(os.path.join(output_path, "eval.jsonl")) + """Test basic functionality on local file.""" + output_path = os.path.join(OUTPUT_DIR, 'basic-test') + split_eval_set_from_args( + TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO + ) + assert os.path.isfile(os.path.join(output_path, 'train.jsonl')) + assert os.path.isfile(os.path.join(output_path, 'eval.jsonl')) def test_basic_split_output_exists(): - """Test that split overwrites existing files in output directory""" - output_path = os.path.join(OUTPUT_DIR, "basic-test") + """Test that split overwrites existing files in output directory.""" + output_path = os.path.join(OUTPUT_DIR, 'basic-test') os.makedirs(output_path, exist_ok=True) - train_file = os.path.join(output_path, "train.jsonl") - eval_file = os.path.join(output_path, "eval.jsonl") - with open(train_file, "w") as f: - f.write("existing file train") - with open(eval_file, "w") as f: - f.write("existing file eval") + train_file = os.path.join(output_path, 'train.jsonl') + eval_file = os.path.join(output_path, 'eval.jsonl') + with open(train_file, 'w') as f: + f.write('existing file train') + with open(eval_file, 'w') as f: + f.write('existing file eval') old_train_hash = calculate_file_hash(train_file) old_eval_hash = calculate_file_hash(eval_file) split_eval_set_from_args( @@ -111,7 +143,7 @@ def test_basic_split_output_exists(): def test_max_eval_samples(): """Test case where max_eval_samples < eval_split_ratio * total samples""" - output_path = os.path.join(OUTPUT_DIR, "max-eval-test") + output_path = os.path.join(OUTPUT_DIR, 'max-eval-test') max_eval_samples = 50 split_eval_set_from_args( TMPT_DIR, @@ -120,72 +152,93 @@ def test_max_eval_samples(): EVAL_SPLIT_RATIO, max_eval_samples, ) - eval_lines = count_lines(os.path.join(output_path, "eval.jsonl")) + eval_lines = count_lines(os.path.join(output_path, 'eval.jsonl')) assert eval_lines == max_eval_samples def test_eval_split_ratio(): - """Test case where max_eval_samples is not used""" - output_path = os.path.join(OUTPUT_DIR, "eval-split-test") - split_eval_set_from_args(TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO) + """Test case where max_eval_samples is not used.""" + output_path = os.path.join(OUTPUT_DIR, 'eval-split-test') + split_eval_set_from_args( + TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO + ) original_data_lines = count_lines(DEFAULT_FILE) - eval_lines = count_lines(os.path.join(output_path, "eval.jsonl")) - assert abs(eval_lines - EVAL_SPLIT_RATIO * original_data_lines) < 1 # allow for rounding errors + eval_lines = count_lines(os.path.join(output_path, 'eval.jsonl')) + assert abs( + eval_lines - EVAL_SPLIT_RATIO * original_data_lines + ) < 1 # allow for rounding errors def test_seed_consistency(): - """Test if the same seed generates consistent splits""" - output_path_1 = os.path.join(OUTPUT_DIR, "seed-test-1") - output_path_2 = os.path.join(OUTPUT_DIR, "seed-test-2") + """Test if the same seed generates consistent splits.""" + output_path_1 = os.path.join(OUTPUT_DIR, 'seed-test-1') + 
output_path_2 = os.path.join(OUTPUT_DIR, 'seed-test-2') split_examples(DEFAULT_FILE, output_path_1, EVAL_SPLIT_RATIO, seed=12345) split_examples(DEFAULT_FILE, output_path_2, EVAL_SPLIT_RATIO, seed=12345) - train_hash_1 = calculate_file_hash(os.path.join(output_path_1, "train.jsonl")) - train_hash_2 = calculate_file_hash(os.path.join(output_path_2, "train.jsonl")) - eval_hash_1 = calculate_file_hash(os.path.join(output_path_1, "eval.jsonl")) - eval_hash_2 = calculate_file_hash(os.path.join(output_path_2, "eval.jsonl")) + train_hash_1 = calculate_file_hash( + os.path.join(output_path_1, 'train.jsonl') + ) + train_hash_2 = calculate_file_hash( + os.path.join(output_path_2, 'train.jsonl') + ) + eval_hash_1 = calculate_file_hash(os.path.join(output_path_1, 'eval.jsonl')) + eval_hash_2 = calculate_file_hash(os.path.join(output_path_2, 'eval.jsonl')) assert train_hash_1 == train_hash_2 assert eval_hash_1 == eval_hash_2 - output_path_3 = os.path.join(OUTPUT_DIR, "seed-test-3") + output_path_3 = os.path.join(OUTPUT_DIR, 'seed-test-3') split_examples(DEFAULT_FILE, output_path_3, EVAL_SPLIT_RATIO, seed=54321) - train_hash_3 = calculate_file_hash(os.path.join(output_path_3, "train.jsonl")) - eval_hash_3 = calculate_file_hash(os.path.join(output_path_3, "eval.jsonl")) + train_hash_3 = calculate_file_hash( + os.path.join(output_path_3, 'train.jsonl') + ) + eval_hash_3 = calculate_file_hash(os.path.join(output_path_3, 'eval.jsonl')) assert train_hash_1 != train_hash_3 assert eval_hash_1 != eval_hash_3 def _mock_get_file(remote_path: str, data_path: str, overwrite: bool): - with open(data_path, "w") as f: + with open(data_path, 'w') as f: for i in range(1000): - f.write(json.dumps({"prompt": "hello world " + str(i), "response": "hi you!"}) + "\n") + f.write( + json.dumps({ + 'prompt': 'hello world ' + str(i), + 'response': 'hi you!' 
+ }) + '\n' + ) def test_remote_store_data_split(): - """Test splitting a dataset from a remote store""" - output_path = os.path.join(OUTPUT_DIR, "remote-split-test") - with patch("composer.utils.get_file", side_effect=_mock_get_file) as mock_get_file: + """Test splitting a dataset from a remote store.""" + output_path = os.path.join(OUTPUT_DIR, 'remote-split-test') + with patch( + 'composer.utils.get_file', side_effect=_mock_get_file + ) as mock_get_file: split_eval_set_from_args( - "dbfs:/Volumes/test/test/test.jsonl", - "unique-split-name", + 'dbfs:/Volumes/test/test/test.jsonl', + 'unique-split-name', output_path, EVAL_SPLIT_RATIO, ) mock_get_file.assert_called() - assert os.path.isfile(os.path.join(output_path, "train.jsonl")) - assert os.path.isfile(os.path.join(output_path, "eval.jsonl")) - assert count_lines(os.path.join(output_path, "train.jsonl")) > 0 - assert count_lines(os.path.join(output_path, "eval.jsonl")) > 0 + assert os.path.isfile(os.path.join(output_path, 'train.jsonl')) + assert os.path.isfile(os.path.join(output_path, 'eval.jsonl')) + assert count_lines(os.path.join(output_path, 'train.jsonl')) > 0 + assert count_lines(os.path.join(output_path, 'eval.jsonl')) > 0 def test_missing_delta_file_error(): # expects file 'TMPT_DIR/missing-00000-of-00001.jsonl with pytest.raises(FileNotFoundError): - split_eval_set_from_args(TMPT_DIR, "missing", OUTPUT_DIR, EVAL_SPLIT_RATIO) + split_eval_set_from_args( + TMPT_DIR, 'missing', OUTPUT_DIR, EVAL_SPLIT_RATIO + ) def test_unknown_file_format_error(): with pytest.raises(ValueError): - split_eval_set_from_args("s3:/path/to/file.jsonl", "train", OUTPUT_DIR, EVAL_SPLIT_RATIO) + split_eval_set_from_args( + 's3:/path/to/file.jsonl', 'train', OUTPUT_DIR, EVAL_SPLIT_RATIO + ) From e7cca17e8b59020b013e9202e0f9d78410d9ba2e Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Mon, 28 Oct 2024 19:50:56 -0400 Subject: [PATCH 10/16] lint Signed-off-by: Jimmy Xu --- .../command_utils/data_prep/split_eval_set.py | 21 ++++--- .../data_prep/test_split_eval_set.py | 62 ++++++++++++------- 2 files changed, 53 insertions(+), 30 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index 223b94241f..187425f420 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -38,11 +38,12 @@ def get_dataset_format(data_path_folder: str) -> str: def maybe_download_data_as_json( - data_path_folder: str, data_path_split: str + data_path_folder: str, + data_path_split: str, ) -> str: - """Prepares dataset as a local JSONL file. Downloads from remote object - store if necessary. + """Prepares dataset as a local JSONL file. + Downloads from remote object store if needed. This function is intended to be invoked by DBX Finetuning. Thus, it assumes the provided data is: 1. A Delta table converted to JSONL at 'tmp-t/{data_path_split}-00000-of-00001.jsonl` @@ -62,10 +63,11 @@ def maybe_download_data_as_json( if dataset_format == 'delta': log.info( - f'Dataset is converted from Delta table. Using local file {data_path_folder}' + f'Dataset is converted from Delta table. 
Using local file {data_path_folder}', ) data_path = os.path.join( - data_path_folder, f'{data_path_split}-00000-of-00001.jsonl' + data_path_folder, + f'{data_path_split}-00000-of-00001.jsonl', ) elif dataset_format == 'remote_object_store': @@ -136,7 +138,8 @@ def split_examples( os.path.join(output_path, 'train.jsonl'), 'w', ) as train_outfile, open( - os.path.join(output_path, 'eval.jsonl'), 'w' + os.path.join(output_path, 'eval.jsonl'), + 'w', ) as eval_outfile: for idx, line in enumerate(infile): if idx in sample_indices: @@ -169,5 +172,9 @@ def split_eval_set_from_args( """ data_path = maybe_download_data_as_json(data_path_folder, data_path_split) split_examples( - data_path, output_path, eval_split_ratio, max_eval_samples, seed + data_path, + output_path, + eval_split_ratio, + max_eval_samples, + seed, ) diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py index a9e09d4ce3..f225d9f69c 100644 --- a/tests/a_scripts/data_prep/test_split_eval_set.py +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -10,7 +10,10 @@ from llmfoundry.command_utils import split_eval_set_from_args, split_examples from llmfoundry.command_utils.data_prep.split_eval_set import ( - DELTA_JSONL_REGEX, REMOTE_OBJECT_STORE_FILE_REGEX, get_dataset_format,) + DELTA_JSONL_REGEX, + REMOTE_OBJECT_STORE_FILE_REGEX, + get_dataset_format, +) # Default values OUTPUT_DIR = 'tmp-split' @@ -34,22 +37,22 @@ def test_remote_object_store_file_regex(): """Test the regex pattern for remote object store file paths.""" assert REMOTE_OBJECT_STORE_FILE_REGEX.match('s3://bucket-name/path/to/file') assert REMOTE_OBJECT_STORE_FILE_REGEX.match( - 'oci://bucket-name/path/to/file' + 'oci://bucket-name/path/to/file', ) assert REMOTE_OBJECT_STORE_FILE_REGEX.match('gs://bucket-name/path/to/file') assert REMOTE_OBJECT_STORE_FILE_REGEX.match('dbfs:/Volumes/path/to/file') assert REMOTE_OBJECT_STORE_FILE_REGEX.match( - 's3://bucket-name/path/to/file with spaces' + 's3://bucket-name/path/to/file with spaces', ) assert not REMOTE_OBJECT_STORE_FILE_REGEX.match( - 'https://bucket-name/path/to/file' + 'https://bucket-name/path/to/file', ) assert not REMOTE_OBJECT_STORE_FILE_REGEX.match('/local/path/to/file') assert not REMOTE_OBJECT_STORE_FILE_REGEX.match( - 's3:/bucket-name/path/to/file' + 's3:/bucket-name/path/to/file', ) assert not REMOTE_OBJECT_STORE_FILE_REGEX.match( - 's3://bucket-name/path/to/file?' + 's3://bucket-name/path/to/file?', ) @@ -61,16 +64,16 @@ def test_get_dataset_format(): # Test remote object store format assert get_dataset_format( - 's3://bucket-name/path/to/file' + 's3://bucket-name/path/to/file', ) == 'remote_object_store' assert get_dataset_format( - 'oci://bucket-name/path/to/file' + 'oci://bucket-name/path/to/file', ) == 'remote_object_store' assert get_dataset_format( - 'gs://bucket-name/path/to/file' + 'gs://bucket-name/path/to/file', ) == 'remote_object_store' assert get_dataset_format( - 'dbfs:/Volumes/path/to/file' + 'dbfs:/Volumes/path/to/file', ) == 'remote_object_store' # Test unknown format @@ -99,8 +102,8 @@ def setup_and_teardown_module(): f.write( json.dumps({ 'prompt': 'hello world ' + str(i), - 'response': 'hi you!' 
- }) + '\n' + 'response': 'hi you!', + }) + '\n', ) yield @@ -113,7 +116,10 @@ def test_basic_split(): """Test basic functionality on local file.""" output_path = os.path.join(OUTPUT_DIR, 'basic-test') split_eval_set_from_args( - TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO + TMPT_DIR, + DATA_PATH_SPLIT, + output_path, + EVAL_SPLIT_RATIO, ) assert os.path.isfile(os.path.join(output_path, 'train.jsonl')) assert os.path.isfile(os.path.join(output_path, 'eval.jsonl')) @@ -160,12 +166,15 @@ def test_eval_split_ratio(): """Test case where max_eval_samples is not used.""" output_path = os.path.join(OUTPUT_DIR, 'eval-split-test') split_eval_set_from_args( - TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO + TMPT_DIR, + DATA_PATH_SPLIT, + output_path, + EVAL_SPLIT_RATIO, ) original_data_lines = count_lines(DEFAULT_FILE) eval_lines = count_lines(os.path.join(output_path, 'eval.jsonl')) assert abs( - eval_lines - EVAL_SPLIT_RATIO * original_data_lines + eval_lines - EVAL_SPLIT_RATIO * original_data_lines, ) < 1 # allow for rounding errors @@ -176,10 +185,10 @@ def test_seed_consistency(): split_examples(DEFAULT_FILE, output_path_1, EVAL_SPLIT_RATIO, seed=12345) split_examples(DEFAULT_FILE, output_path_2, EVAL_SPLIT_RATIO, seed=12345) train_hash_1 = calculate_file_hash( - os.path.join(output_path_1, 'train.jsonl') + os.path.join(output_path_1, 'train.jsonl'), ) train_hash_2 = calculate_file_hash( - os.path.join(output_path_2, 'train.jsonl') + os.path.join(output_path_2, 'train.jsonl'), ) eval_hash_1 = calculate_file_hash(os.path.join(output_path_1, 'eval.jsonl')) eval_hash_2 = calculate_file_hash(os.path.join(output_path_2, 'eval.jsonl')) @@ -190,7 +199,7 @@ def test_seed_consistency(): output_path_3 = os.path.join(OUTPUT_DIR, 'seed-test-3') split_examples(DEFAULT_FILE, output_path_3, EVAL_SPLIT_RATIO, seed=54321) train_hash_3 = calculate_file_hash( - os.path.join(output_path_3, 'train.jsonl') + os.path.join(output_path_3, 'train.jsonl'), ) eval_hash_3 = calculate_file_hash(os.path.join(output_path_3, 'eval.jsonl')) @@ -204,8 +213,8 @@ def _mock_get_file(remote_path: str, data_path: str, overwrite: bool): f.write( json.dumps({ 'prompt': 'hello world ' + str(i), - 'response': 'hi you!' 
- }) + '\n' + 'response': 'hi you!', + }) + '\n', ) @@ -213,7 +222,8 @@ def test_remote_store_data_split(): """Test splitting a dataset from a remote store.""" output_path = os.path.join(OUTPUT_DIR, 'remote-split-test') with patch( - 'composer.utils.get_file', side_effect=_mock_get_file + 'composer.utils.get_file', + side_effect=_mock_get_file, ) as mock_get_file: split_eval_set_from_args( 'dbfs:/Volumes/test/test/test.jsonl', @@ -233,12 +243,18 @@ def test_missing_delta_file_error(): # expects file 'TMPT_DIR/missing-00000-of-00001.jsonl with pytest.raises(FileNotFoundError): split_eval_set_from_args( - TMPT_DIR, 'missing', OUTPUT_DIR, EVAL_SPLIT_RATIO + TMPT_DIR, + 'missing', + OUTPUT_DIR, + EVAL_SPLIT_RATIO, ) def test_unknown_file_format_error(): with pytest.raises(ValueError): split_eval_set_from_args( - 's3:/path/to/file.jsonl', 'train', OUTPUT_DIR, EVAL_SPLIT_RATIO + 's3:/path/to/file.jsonl', + 'train', + OUTPUT_DIR, + EVAL_SPLIT_RATIO, ) From a4327bb9733e43482e8ebe812d1143fd5a9e58d5 Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Mon, 28 Oct 2024 20:12:38 -0400 Subject: [PATCH 11/16] some comments Signed-off-by: Jimmy Xu --- .../command_utils/data_prep/split_eval_set.py | 11 +-- .../data_prep/test_split_eval_set.py | 88 +++++++++---------- 2 files changed, 46 insertions(+), 53 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index 187425f420..51a8c28d7d 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -37,7 +37,7 @@ def get_dataset_format(data_path_folder: str) -> str: TEMP_DIR = 'tmp-split' -def maybe_download_data_as_json( +def maybe_download_data_as_jsonl( data_path_folder: str, data_path_split: str, ) -> str: @@ -129,9 +129,10 @@ def split_examples( if max_eval_samples is not None: sample_size = min(sample_size, max_eval_samples) - with temp_seed(seed) if seed is not None else contextlib.nullcontext(): - random_numbers = np.random.rand(total_lines) - sample_indices = set(np.argsort(random_numbers)[:sample_size]) + # Use a new RNG instance with the provided seed + rng = np.random.default_rng(seed) + random_numbers = rng.random(total_lines) + sample_indices = set(np.argsort(random_numbers)[:sample_size]) # second pass: sample indices with open(data_path, 'r') as infile, open( @@ -170,7 +171,7 @@ def split_eval_set_from_args( max_eval_samples (int): Maximum number of samples to include in the eval set. 
If None, all eval_split_ratio * train_dataset_size samples will be used seed (int): Random seed for splitting the dataset """ - data_path = maybe_download_data_as_json(data_path_folder, data_path_split) + data_path = maybe_download_data_as_jsonl(data_path_folder, data_path_split) split_examples( data_path, output_path, diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py index f225d9f69c..bc881690c6 100644 --- a/tests/a_scripts/data_prep/test_split_eval_set.py +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -23,63 +23,55 @@ DEFAULT_FILE = TMPT_DIR + '/train-00000-of-00001.jsonl' -def test_delta_jsonl_regex(): +@pytest.mark.parametrize("test_input: str, expected: bool", [ + ('tmp-t', True), + ('/tmp-t', False), + ('tmp-t-00000-of-00001.jsonl', False), + ('tmp-t-something', False), + ('tmp-t/', False), + ('tmp-t\\', False) +]) +def test_delta_jsonl_regex(test_input: str, expected: bool) -> None: """Test the regex pattern matches tmp-t exactly.""" - assert DELTA_JSONL_REGEX.match('tmp-t') - assert not DELTA_JSONL_REGEX.match('/tmp-t') - assert not DELTA_JSONL_REGEX.match('tmp-t-00000-of-00001.jsonl') - assert not DELTA_JSONL_REGEX.match('tmp-t-something') - assert not DELTA_JSONL_REGEX.match('tmp-t/') - assert not DELTA_JSONL_REGEX.match('tmp-t\\') - - -def test_remote_object_store_file_regex(): + assert bool(DELTA_JSONL_REGEX.match(test_input)) == expected + + + +@pytest.mark.parametrize("test_input: str, expected: bool", [ + ('s3://bucket-name/path/to/file', True), + ('oci://bucket-name/path/to/file', True), + ('gs://bucket-name/path/to/file', True), + ('dbfs:/Volumes/path/to/file', True), + ('s3://bucket-name/path/to/file with spaces', True), + ('https://bucket-name/path/to/file', False), + ('/local/path/to/file', False), + ('s3:/bucket-name/path/to/file', False), + ('s3://bucket-name/path/to/file?', False) +]) +def test_remote_object_store_file_regex(test_input: str, expected: bool) -> None: """Test the regex pattern for remote object store file paths.""" - assert REMOTE_OBJECT_STORE_FILE_REGEX.match('s3://bucket-name/path/to/file') - assert REMOTE_OBJECT_STORE_FILE_REGEX.match( - 'oci://bucket-name/path/to/file', - ) - assert REMOTE_OBJECT_STORE_FILE_REGEX.match('gs://bucket-name/path/to/file') - assert REMOTE_OBJECT_STORE_FILE_REGEX.match('dbfs:/Volumes/path/to/file') - assert REMOTE_OBJECT_STORE_FILE_REGEX.match( - 's3://bucket-name/path/to/file with spaces', - ) - assert not REMOTE_OBJECT_STORE_FILE_REGEX.match( - 'https://bucket-name/path/to/file', - ) - assert not REMOTE_OBJECT_STORE_FILE_REGEX.match('/local/path/to/file') - assert not REMOTE_OBJECT_STORE_FILE_REGEX.match( - 's3:/bucket-name/path/to/file', - ) - assert not REMOTE_OBJECT_STORE_FILE_REGEX.match( - 's3://bucket-name/path/to/file?', - ) + assert bool(REMOTE_OBJECT_STORE_FILE_REGEX.match(test_input)) == expected -def test_get_dataset_format(): - """Test the get_dataset_format function.""" +@pytest.mark.parametrize("test_input: str, expected: str", [ # Test delta format - assert get_dataset_format('tmp-t') == 'delta' - assert get_dataset_format('tmp-t/') == 'unknown' + ('tmp-t', 'delta'), + ('tmp-t/', 'unknown'), # Test remote object store format - assert get_dataset_format( - 's3://bucket-name/path/to/file', - ) == 'remote_object_store' - assert get_dataset_format( - 'oci://bucket-name/path/to/file', - ) == 'remote_object_store' - assert get_dataset_format( - 'gs://bucket-name/path/to/file', - ) == 'remote_object_store' - assert get_dataset_format( - 
'dbfs:/Volumes/path/to/file', - ) == 'remote_object_store' + ('s3://bucket-name/path/to/file', 'remote_object_store'), + ('oci://bucket-name/path/to/file', 'remote_object_store'), + ('gs://bucket-name/path/to/file', 'remote_object_store'), + ('dbfs:/Volumes/path/to/file', 'remote_object_store'), # Test unknown format - assert get_dataset_format('/local/path/to/file') == 'unknown' - assert get_dataset_format('s3:/bucket-name/path/to/file') == 'unknown' - assert get_dataset_format('dataset:name') == 'unknown' + ('/local/path/to/file', 'unknown'), + ('s3:/bucket-name/path/to/file', 'unknown'), + ('dataset:name', 'unknown') +]) +def test_get_dataset_format(test_input: str, expected: str) -> None: + """Test the get_dataset_format function.""" + assert get_dataset_format(test_input) == expected def calculate_file_hash(filepath: str) -> str: From 39676be52a10cf8887754ebbd71529575fa4d067 Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Tue, 29 Oct 2024 11:48:27 -0400 Subject: [PATCH 12/16] comments Signed-off-by: Jimmy Xu --- .../command_utils/data_prep/split_eval_set.py | 8 +- .../data_prep/test_split_eval_set.py | 81 +++++++++---------- 2 files changed, 42 insertions(+), 47 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index 51a8c28d7d..544539dc4f 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -12,7 +12,7 @@ log = logging.getLogger(__name__) -DELTA_JSONL_REGEX = re.compile(r'^tmp-t$') +LOCAL_PATH = 'tmp-t' REMOTE_OBJECT_STORE_FILE_REGEX = re.compile( r'^((s3|oci|gs):\/\/|dbfs:\/Volumes\/)[/a-zA-Z0-9 ()_\-.]+$', ) @@ -27,8 +27,8 @@ def get_dataset_format(data_path_folder: str) -> str: Returns: str: The format of the dataset """ - if DELTA_JSONL_REGEX.match(data_path_folder): - return 'delta' + if data_path_folder == LOCAL_PATH: + return 'local_file' if REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): return 'remote_object_store' return 'unknown' @@ -61,7 +61,7 @@ def maybe_download_data_as_jsonl( dataset_format = get_dataset_format(data_path_folder) - if dataset_format == 'delta': + if dataset_format == 'local_file': log.info( f'Dataset is converted from Delta table. 
Using local file {data_path_folder}', ) diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py index bc881690c6..ba1db3ac02 100644 --- a/tests/a_scripts/data_prep/test_split_eval_set.py +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -10,7 +10,6 @@ from llmfoundry.command_utils import split_eval_set_from_args, split_examples from llmfoundry.command_utils.data_prep.split_eval_set import ( - DELTA_JSONL_REGEX, REMOTE_OBJECT_STORE_FILE_REGEX, get_dataset_format, ) @@ -23,52 +22,48 @@ DEFAULT_FILE = TMPT_DIR + '/train-00000-of-00001.jsonl' -@pytest.mark.parametrize("test_input: str, expected: bool", [ - ('tmp-t', True), - ('/tmp-t', False), - ('tmp-t-00000-of-00001.jsonl', False), - ('tmp-t-something', False), - ('tmp-t/', False), - ('tmp-t\\', False) -]) -def test_delta_jsonl_regex(test_input: str, expected: bool) -> None: - """Test the regex pattern matches tmp-t exactly.""" - assert bool(DELTA_JSONL_REGEX.match(test_input)) == expected - - - -@pytest.mark.parametrize("test_input: str, expected: bool", [ - ('s3://bucket-name/path/to/file', True), - ('oci://bucket-name/path/to/file', True), - ('gs://bucket-name/path/to/file', True), - ('dbfs:/Volumes/path/to/file', True), - ('s3://bucket-name/path/to/file with spaces', True), - ('https://bucket-name/path/to/file', False), - ('/local/path/to/file', False), - ('s3:/bucket-name/path/to/file', False), - ('s3://bucket-name/path/to/file?', False) -]) -def test_remote_object_store_file_regex(test_input: str, expected: bool) -> None: +@pytest.mark.parametrize( + 'test_input, expected', + [ + ('s3://bucket-name/path/to/file', True), + ('oci://bucket-name/path/to/file', True), + ('gs://bucket-name/path/to/file', True), + ('dbfs:/Volumes/path/to/file', True), + ('s3://bucket-name/path/to/file with spaces', True), + ('https://bucket-name/path/to/file', False), + ('/local/path/to/file', False), + ('s3:/bucket-name/path/to/file', False), + ('s3://bucket-name/path/to/file?', False), + ], +) +def test_remote_object_store_file_regex( + test_input: str, + expected: bool, +) -> None: """Test the regex pattern for remote object store file paths.""" assert bool(REMOTE_OBJECT_STORE_FILE_REGEX.match(test_input)) == expected -@pytest.mark.parametrize("test_input: str, expected: str", [ - # Test delta format - ('tmp-t', 'delta'), - ('tmp-t/', 'unknown'), - - # Test remote object store format - ('s3://bucket-name/path/to/file', 'remote_object_store'), - ('oci://bucket-name/path/to/file', 'remote_object_store'), - ('gs://bucket-name/path/to/file', 'remote_object_store'), - ('dbfs:/Volumes/path/to/file', 'remote_object_store'), - - # Test unknown format - ('/local/path/to/file', 'unknown'), - ('s3:/bucket-name/path/to/file', 'unknown'), - ('dataset:name', 'unknown') -]) +@pytest.mark.parametrize( + 'test_input, expected', + [ + # Test delta format + ('tmp-t', 'local_file'), + ('/tmp-t', 'unknown'), + ('tmp-t/', 'unknown'), + + # Test remote object store format + ('s3://bucket-name/path/to/file', 'remote_object_store'), + ('oci://bucket-name/path/to/file', 'remote_object_store'), + ('gs://bucket-name/path/to/file', 'remote_object_store'), + ('dbfs:/Volumes/path/to/file', 'remote_object_store'), + + # Test unknown format + ('/local/path/to/file', 'unknown'), + ('s3:/bucket-name/path/to/file', 'unknown'), + ('dataset:name', 'unknown'), + ], +) def test_get_dataset_format(test_input: str, expected: str) -> None: """Test the get_dataset_format function.""" assert get_dataset_format(test_input) == expected 
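
Before the next patch lands, it helps to recall the sampling change already made in patch 11: the temp_seed context manager, which saved and restored numpy's global RNG state, was replaced by a per-call numpy Generator, and patch 13 below simply deletes the now-unused helper. What follows is a minimal, self-contained sketch of that seeded index-sampling idea, assuming only that numpy is installed; the function name sample_eval_indices and the sizes used are illustrative and are not part of the repository.

import numpy as np


def sample_eval_indices(total_lines: int, sample_size: int, seed: int) -> set:
    # A fresh Generator keyed by the seed; numpy's global RNG state is untouched.
    rng = np.random.default_rng(seed)
    # One random score per input line; the lowest `sample_size` scores are selected.
    random_numbers = rng.random(total_lines)
    return set(np.argsort(random_numbers)[:sample_size])


# Identical seeds reproduce the same eval indices, which is what the
# seed-consistency test in the test suite above relies on.
assert sample_eval_indices(1000, 100, seed=12345) == sample_eval_indices(1000, 100, seed=12345)

Because the Generator is local to the call, there is no global state to save and restore, which is why the contextlib-based temp_seed helper can be dropped in the next patch.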
From a35c6d993aa52acb7309c93783540070682bb1dc Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Tue, 29 Oct 2024 14:31:18 -0400 Subject: [PATCH 13/16] make tmpdir Signed-off-by: Jimmy Xu --- .../command_utils/data_prep/split_eval_set.py | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index 544539dc4f..1ef06fd054 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -1,10 +1,10 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import contextlib import logging import os import re +import tempfile from typing import Optional import composer.utils as utils @@ -34,9 +34,6 @@ def get_dataset_format(data_path_folder: str) -> str: return 'unknown' -TEMP_DIR = 'tmp-split' - - def maybe_download_data_as_jsonl( data_path_folder: str, data_path_split: str, @@ -57,7 +54,7 @@ def maybe_download_data_as_jsonl( Returns: str: Path to the training dataset """ - os.makedirs(TEMP_DIR, exist_ok=True) + TEMP_DIR = tempfile.mkdtemp() dataset_format = get_dataset_format(data_path_folder) @@ -91,17 +88,6 @@ def maybe_download_data_as_jsonl( return data_path -@contextlib.contextmanager -def temp_seed(seed: int): - log.info(f'Setting random seed to {seed}') - state = np.random.get_state() - np.random.seed(seed) - try: - yield - finally: - np.random.set_state(state) - - def split_examples( data_path: str, output_path: str, @@ -132,6 +118,11 @@ def split_examples( # Use a new RNG instance with the provided seed rng = np.random.default_rng(seed) random_numbers = rng.random(total_lines) + + # TODO: Consider using reservoir sampling for large datasets + # Jimmy doesn't think we need to do this right now, since we will + # migrate all of this splitting logic to workflows later anyways, so + # we can do it then sample_indices = set(np.argsort(random_numbers)[:sample_size]) # second pass: sample indices From 7ec198835d6e9cdfca3922952ad9b106c7ecb5ac Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Tue, 29 Oct 2024 14:37:14 -0400 Subject: [PATCH 14/16] default to local Signed-off-by: Jimmy Xu --- .../command_utils/data_prep/split_eval_set.py | 39 +++++++------------ .../data_prep/test_split_eval_set.py | 32 +++++++-------- 2 files changed, 26 insertions(+), 45 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index 1ef06fd054..1ffd3b1c8f 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -12,26 +12,21 @@ log = logging.getLogger(__name__) -LOCAL_PATH = 'tmp-t' REMOTE_OBJECT_STORE_FILE_REGEX = re.compile( r'^((s3|oci|gs):\/\/|dbfs:\/Volumes\/)[/a-zA-Z0-9 ()_\-.]+$', ) -def get_dataset_format(data_path_folder: str) -> str: - """Determine the format of the dataset from the provided data path. +def is_remote_object_store_file(data_path_folder: str) -> bool: + """Check if the provided data path is a remote object store file. 
Args: data_path_folder (str): Path to the training dataset folder Returns: - str: The format of the dataset + bool: True if the data path is a remote object store file """ - if data_path_folder == LOCAL_PATH: - return 'local_file' - if REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): - return 'remote_object_store' - return 'unknown' + return REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder) is not None def maybe_download_data_as_jsonl( @@ -43,9 +38,9 @@ def maybe_download_data_as_jsonl( Downloads from remote object store if needed. This function is intended to be invoked by DBX Finetuning. Thus, it assumes the provided data is: - 1. A Delta table converted to JSONL at 'tmp-t/{data_path_split}-00000-of-00001.jsonl` + 1. A JSONL stored as a remote object store file (e.g. S3, OCI, GCS) + 2. A Delta table converted to JSONL at 'tmp-t/{data_path_split}-00000-of-00001.jsonl` using the 'llmfoundry.scripts.convert_delta_to_json.py' script. - 2. A JSONL stored as a remote object store file (e.g. S3, OCI, GCS) Args: data_path_folder (str): Path to the training dataset folder @@ -56,28 +51,20 @@ def maybe_download_data_as_jsonl( """ TEMP_DIR = tempfile.mkdtemp() - dataset_format = get_dataset_format(data_path_folder) - - if dataset_format == 'local_file': - log.info( - f'Dataset is converted from Delta table. Using local file {data_path_folder}', - ) - data_path = os.path.join( - data_path_folder, - f'{data_path_split}-00000-of-00001.jsonl', - ) - - elif dataset_format == 'remote_object_store': + if is_remote_object_store_file(data_path_folder): log.info( f'Downloading dataset from remote object store: {data_path_folder}{data_path_split}.jsonl', ) remote_path = f'{data_path_folder}/{data_path_split}.jsonl' data_path = os.path.join(TEMP_DIR, f'{data_path_split}.jsonl') utils.get_file(remote_path, data_path, overwrite=True) - else: - raise ValueError( - f'Encountered unknown data path format when splitting dataset: {data_path_folder} with split {data_path_split}', + log.info( + f'Dataset is converted from Delta table. 
Using local file {data_path_folder}', + ) + data_path = os.path.join( + data_path_folder, + f'{data_path_split}-00000-of-00001.jsonl', ) if not os.path.exists(data_path): diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py index ba1db3ac02..7e19b29593 100644 --- a/tests/a_scripts/data_prep/test_split_eval_set.py +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -11,7 +11,7 @@ from llmfoundry.command_utils import split_eval_set_from_args, split_examples from llmfoundry.command_utils.data_prep.split_eval_set import ( REMOTE_OBJECT_STORE_FILE_REGEX, - get_dataset_format, + is_remote_object_store_file, ) # Default values @@ -47,26 +47,20 @@ def test_remote_object_store_file_regex( @pytest.mark.parametrize( 'test_input, expected', [ - # Test delta format - ('tmp-t', 'local_file'), - ('/tmp-t', 'unknown'), - ('tmp-t/', 'unknown'), - - # Test remote object store format - ('s3://bucket-name/path/to/file', 'remote_object_store'), - ('oci://bucket-name/path/to/file', 'remote_object_store'), - ('gs://bucket-name/path/to/file', 'remote_object_store'), - ('dbfs:/Volumes/path/to/file', 'remote_object_store'), - - # Test unknown format - ('/local/path/to/file', 'unknown'), - ('s3:/bucket-name/path/to/file', 'unknown'), - ('dataset:name', 'unknown'), + ('s3://bucket-name/path/to/file', True), + ('oci://bucket-name/path/to/file', True), + ('gs://bucket-name/path/to/file', True), + ('dbfs:/Volumes/path/to/file', True), + ('s3://bucket-name/path/to/file with spaces', True), + ('https://bucket-name/path/to/file', False), + ('/local/path/to/dir', False), + ('s3:/bucket-name/path/to/file', False), + ('s3://bucket-name/path/to/file?', False), ], ) -def test_get_dataset_format(test_input: str, expected: str) -> None: - """Test the get_dataset_format function.""" - assert get_dataset_format(test_input) == expected +def test_is_remote_object_store_file(test_input: str, expected: bool) -> None: + """Test the is_remote_object_store_file function.""" + assert is_remote_object_store_file(test_input) == expected def calculate_file_hash(filepath: str) -> str: From c97fbb508b0e7d7f23f9df3151947a1d40559913 Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Tue, 29 Oct 2024 14:39:23 -0400 Subject: [PATCH 15/16] remove unknown test Signed-off-by: Jimmy Xu --- tests/a_scripts/data_prep/test_split_eval_set.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py index 7e19b29593..3ac8c100f7 100644 --- a/tests/a_scripts/data_prep/test_split_eval_set.py +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -229,13 +229,3 @@ def test_missing_delta_file_error(): OUTPUT_DIR, EVAL_SPLIT_RATIO, ) - - -def test_unknown_file_format_error(): - with pytest.raises(ValueError): - split_eval_set_from_args( - 's3:/path/to/file.jsonl', - 'train', - OUTPUT_DIR, - EVAL_SPLIT_RATIO, - ) From dd0de097b2dd87e2a17379b97289d073a97a0da0 Mon Sep 17 00:00:00 2001 From: Jimmy Xu Date: Tue, 29 Oct 2024 15:48:06 -0400 Subject: [PATCH 16/16] rename Signed-off-by: Jimmy Xu --- llmfoundry/command_utils/__init__.py | 6 +++--- ....py => split_eval_data_from_train_data.py} | 4 ++-- ....py => split_eval_data_from_train_data.py} | 4 ++-- .../data_prep/test_split_eval_set.py | 19 +++++++++++-------- 4 files changed, 18 insertions(+), 15 deletions(-) rename llmfoundry/command_utils/data_prep/{split_eval_set.py => split_eval_data_from_train_data.py} (98%) rename 
scripts/data_prep/{split_eval_set.py => split_eval_data_from_train_data.py} (92%) diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py index ee535237fd..617f17a642 100644 --- a/llmfoundry/command_utils/__init__.py +++ b/llmfoundry/command_utils/__init__.py @@ -23,8 +23,8 @@ convert_text_to_mds, convert_text_to_mds_from_args, ) -from llmfoundry.command_utils.data_prep.split_eval_set import ( - split_eval_set_from_args, +from llmfoundry.command_utils.data_prep.split_eval_data_from_train_data import ( + split_eval_data_from_train_data_from_args, split_examples, ) from llmfoundry.command_utils.eval import ( @@ -58,6 +58,6 @@ 'convert_text_to_mds_from_args', 'convert_delta_to_json_from_args', 'fetch_DT', - 'split_eval_set_from_args', + 'split_eval_data_from_train_data_from_args', 'split_examples', ] diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_data_from_train_data.py similarity index 98% rename from llmfoundry/command_utils/data_prep/split_eval_set.py rename to llmfoundry/command_utils/data_prep/split_eval_data_from_train_data.py index 1ffd3b1c8f..10a8537ed6 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_data_from_train_data.py @@ -131,7 +131,7 @@ def split_examples( ) -def split_eval_set_from_args( +def split_eval_data_from_train_data_from_args( data_path_folder: str, data_path_split: str, output_path: str, @@ -139,7 +139,7 @@ def split_eval_set_from_args( max_eval_samples: Optional[int] = None, seed: Optional[int] = None, ) -> None: - """A wrapper for split_eval_set that parses arguments. + """A wrapper for split_examples that parses arguments. Args: data_path_folder (str): Path to the training dataset folder diff --git a/scripts/data_prep/split_eval_set.py b/scripts/data_prep/split_eval_data_from_train_data.py similarity index 92% rename from scripts/data_prep/split_eval_set.py rename to scripts/data_prep/split_eval_data_from_train_data.py index 42aa1c82f0..20e248cdfd 100644 --- a/scripts/data_prep/split_eval_set.py +++ b/scripts/data_prep/split_eval_data_from_train_data.py @@ -3,7 +3,7 @@ from argparse import ArgumentParser -from llmfoundry.command_utils import split_eval_set_from_args +from llmfoundry.command_utils import split_eval_data_from_train_data_from_args if __name__ == '__main__': parser = ArgumentParser( @@ -51,7 +51,7 @@ help='Random seed for splitting the dataset', ) args = parser.parse_args() - split_eval_set_from_args( + split_eval_data_from_train_data_from_args( data_path_folder=args.data_path_folder, data_path_split=args.data_path_split, output_path=args.output_path, diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py index 3ac8c100f7..7f9a50b351 100644 --- a/tests/a_scripts/data_prep/test_split_eval_set.py +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -8,8 +8,11 @@ import pytest -from llmfoundry.command_utils import split_eval_set_from_args, split_examples -from llmfoundry.command_utils.data_prep.split_eval_set import ( +from llmfoundry.command_utils import ( + split_eval_data_from_train_data_from_args, + split_examples, +) +from llmfoundry.command_utils.data_prep.split_eval_data_from_train_data import ( REMOTE_OBJECT_STORE_FILE_REGEX, is_remote_object_store_file, ) @@ -96,7 +99,7 @@ def setup_and_teardown_module(): def test_basic_split(): """Test basic functionality on local file.""" output_path = os.path.join(OUTPUT_DIR, 
'basic-test') - split_eval_set_from_args( + split_eval_data_from_train_data_from_args( TMPT_DIR, DATA_PATH_SPLIT, output_path, @@ -118,7 +121,7 @@ def test_basic_split_output_exists(): f.write('existing file eval') old_train_hash = calculate_file_hash(train_file) old_eval_hash = calculate_file_hash(eval_file) - split_eval_set_from_args( + split_eval_data_from_train_data_from_args( TMPT_DIR, DATA_PATH_SPLIT, output_path, @@ -132,7 +135,7 @@ def test_max_eval_samples(): """Test case where max_eval_samples < eval_split_ratio * total samples""" output_path = os.path.join(OUTPUT_DIR, 'max-eval-test') max_eval_samples = 50 - split_eval_set_from_args( + split_eval_data_from_train_data_from_args( TMPT_DIR, DATA_PATH_SPLIT, output_path, @@ -146,7 +149,7 @@ def test_max_eval_samples(): def test_eval_split_ratio(): """Test case where max_eval_samples is not used.""" output_path = os.path.join(OUTPUT_DIR, 'eval-split-test') - split_eval_set_from_args( + split_eval_data_from_train_data_from_args( TMPT_DIR, DATA_PATH_SPLIT, output_path, @@ -206,7 +209,7 @@ def test_remote_store_data_split(): 'composer.utils.get_file', side_effect=_mock_get_file, ) as mock_get_file: - split_eval_set_from_args( + split_eval_data_from_train_data_from_args( 'dbfs:/Volumes/test/test/test.jsonl', 'unique-split-name', output_path, @@ -223,7 +226,7 @@ def test_remote_store_data_split(): def test_missing_delta_file_error(): # expects file 'TMPT_DIR/missing-00000-of-00001.jsonl with pytest.raises(FileNotFoundError): - split_eval_set_from_args( + split_eval_data_from_train_data_from_args( TMPT_DIR, 'missing', OUTPUT_DIR,
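
Taken together, the series ends with split_eval_data_from_train_data_from_args as the renamed public entry point for carving an eval split out of a training dataset. For orientation, here is a hedged usage sketch of that helper as it stands after the final rename; the dbfs:/Volumes path below is a placeholder, and the keyword names simply mirror the signature shown earlier in the series rather than any additional API.

from llmfoundry.command_utils import split_eval_data_from_train_data_from_args

# Download {data_path_folder}/train.jsonl from the (placeholder) remote volume,
# then write train.jsonl and eval.jsonl under output_path with a reproducible
# ~90/10 split capped at 100 eval examples.
split_eval_data_from_train_data_from_args(
    data_path_folder='dbfs:/Volumes/main/my_schema/my_volume',  # placeholder path
    data_path_split='train',
    output_path='/tmp-split',
    eval_split_ratio=0.1,
    max_eval_samples=100,
    seed=42,
)

The same flags are exposed on the command line through scripts/data_prep/split_eval_data_from_train_data.py after the rename in the final patch.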