From b04651f3c817fdb774243fc741bd5838c5b76241 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Mon, 16 Sep 2024 00:58:53 -0700 Subject: [PATCH] error handling and testing --- llmfoundry/command_utils/__init__.py | 6 +- .../command_utils/data_prep/split_eval_set.py | 38 ++-- .../data_prep/test_split_eval_set.py | 163 ++++++++++++++++++ 3 files changed, 183 insertions(+), 24 deletions(-) create mode 100644 tests/a_scripts/data_prep/test_split_eval_set.py diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py index 5407b723cc..8757f3b1bc 100644 --- a/llmfoundry/command_utils/__init__.py +++ b/llmfoundry/command_utils/__init__.py @@ -20,7 +20,10 @@ convert_text_to_mds, convert_text_to_mds_from_args, ) -from llmfoundry.command_utils.data_prep.split_eval_set import split_eval_set_from_args +from llmfoundry.command_utils.data_prep.split_eval_set import ( + split_eval_set_from_args, + split_examples, +) from llmfoundry.command_utils.eval import ( eval_from_yaml, evaluate, @@ -52,4 +55,5 @@ "convert_delta_to_json_from_args", "fetch_DT", "split_eval_set_from_args", + "split_examples", ] diff --git a/llmfoundry/command_utils/data_prep/split_eval_set.py b/llmfoundry/command_utils/data_prep/split_eval_set.py index f6afc8722d..b4b150f81f 100644 --- a/llmfoundry/command_utils/data_prep/split_eval_set.py +++ b/llmfoundry/command_utils/data_prep/split_eval_set.py @@ -10,7 +10,7 @@ import numpy as np from typing import Optional -from composer.utils import get_file +import composer.utils as utils from llmfoundry.data.finetuning.tasks import maybe_safe_download_hf_data @@ -24,11 +24,6 @@ log = logging.getLogger(__name__) -import sys - -log.setLevel(logging.DEBUG) -log.addHandler(logging.StreamHandler(sys.stdout)) - def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> str: """ @@ -51,22 +46,16 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> os.makedirs(TEMP_DIR, exist_ok=True) if DELTA_JSONL_REGEX.match(data_path_folder): + log.info(f"Dataset is converted from Delta table. Using local file {data_path_folder}") data_path = os.path.join(data_path_folder, f"{data_path_split}-00000-of-00001.jsonl") - if not os.path.exists(data_path): - # TODO: error handling - raise FileNotFoundError(f"File {data_path} does not exist.") - if REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): + elif REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder): log.info( f"Downloading dataset from remote object store: {data_path_folder}{data_path_split}.jsonl" ) remote_path = f"{data_path_folder}/{data_path_split}.jsonl" data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl") - try: - get_file(remote_path, data_path, overwrite=True) - except FileNotFoundError as e: - # TODO: error handling - raise e + utils.get_file(remote_path, data_path, overwrite=True) elif HF_REGEX.match(data_path_folder): log.info( @@ -85,20 +74,21 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> f.write(json.dumps(example) + "\n") else: - # TODO: error handling raise ValueError( - f"Unrecognized data_path_folder: {data_path_folder}. Must be a Delta table, remote object store file, or Hugging Face dataset." 
+ f"Encountered unknown data path format when splitting dataset: {data_path_folder} with split {data_path_split}" ) if not os.path.exists(data_path): - # TODO: error handling - raise FileNotFoundError(f"File {data_path} does not exist.") + raise FileNotFoundError( + f"Expected dataset file at {data_path} for splitting, but it does not exist." + ) return data_path @contextlib.contextmanager def temp_seed(seed: int): + log.info(f"Setting random seed to {seed}") state = np.random.get_state() np.random.seed(seed) try: @@ -107,11 +97,11 @@ def temp_seed(seed: int): np.random.set_state(state) -def _split_examples( +def split_examples( data_path: str, output_path: str, eval_split_ratio: float, - max_eval_samples: Optional[int], + max_eval_samples: Optional[int] = None, seed: Optional[int] = None, ) -> None: """ @@ -119,10 +109,13 @@ def _split_examples( Args: data_path (str): Path to the training dataset (local jsonl file) + output_path (str): Directory to save the split dataset eval_split_ratio (float): Ratio of the dataset to use for evaluation. The remainder will be used for training max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used seed (int): Random seed for splitting the dataset """ + os.makedirs(output_path, exist_ok=True) + # first pass: count total number of lines and determine sample size total_lines = 0 with open(data_path, "r") as infile: @@ -170,6 +163,5 @@ def split_eval_set_from_args( max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used seed (int): Random seed for splitting the dataset """ - os.makedirs(output_path, exist_ok=True) data_path = maybe_download_data_as_json(data_path_folder, data_path_split) - _split_examples(data_path, output_path, eval_split_ratio, max_eval_samples, seed) + split_examples(data_path, output_path, eval_split_ratio, max_eval_samples, seed) diff --git a/tests/a_scripts/data_prep/test_split_eval_set.py b/tests/a_scripts/data_prep/test_split_eval_set.py new file mode 100644 index 0000000000..a1b80b91cd --- /dev/null +++ b/tests/a_scripts/data_prep/test_split_eval_set.py @@ -0,0 +1,163 @@ +import os +import json +import pytest +import hashlib +from unittest.mock import patch + +from llmfoundry.command_utils import split_eval_set_from_args, split_examples + +# Default values +OUTPUT_DIR = "tmp-split" +TMPT_DIR = "tmp-t" +DATA_PATH_SPLIT = "train" +EVAL_SPLIT_RATIO = 0.1 +DEFAULT_FILE = TMPT_DIR + "/train-00000-of-00001.jsonl" + + +def calculate_file_hash(filepath: str) -> str: + with open(filepath, "rb") as f: + file_hash = hashlib.sha256(f.read()).hexdigest() + return file_hash + + +def count_lines(filepath: str) -> int: + with open(filepath, "r") as f: + return sum(1 for _ in f) + + +@pytest.fixture(scope="module", autouse=True) +def setup_and_teardown_module(): + # Setup: create local testing file + os.makedirs(TMPT_DIR, exist_ok=True) + with open(DEFAULT_FILE, "w") as f: + for i in range(1000): + f.write(json.dumps({"prompt": "hello world " + str(i), "response": "hi you!"}) + "\n") + yield + + # Teardown: clean up output and tmp directories + os.system(f"rm -rf {OUTPUT_DIR}") + os.system(f"rm -rf {TMPT_DIR}") + + +def test_basic_split(): + """Test basic functionality on local file""" + output_path = os.path.join(OUTPUT_DIR, "basic-test") + split_eval_set_from_args(TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO) + assert 
os.path.isfile(os.path.join(output_path, "train.jsonl")) + assert os.path.isfile(os.path.join(output_path, "eval.jsonl")) + + +def test_basic_split_output_exists(): + """Test that split overwrites existing files in output directory""" + output_path = os.path.join(OUTPUT_DIR, "basic-test") + os.makedirs(output_path, exist_ok=True) + train_file = os.path.join(output_path, "train.jsonl") + eval_file = os.path.join(output_path, "eval.jsonl") + with open(train_file, "w") as f: + f.write("existing file train") + with open(eval_file, "w") as f: + f.write("existing file eval") + old_train_hash = calculate_file_hash(train_file) + old_eval_hash = calculate_file_hash(eval_file) + split_eval_set_from_args( + TMPT_DIR, + DATA_PATH_SPLIT, + output_path, + EVAL_SPLIT_RATIO, + ) + assert calculate_file_hash(train_file) != old_train_hash + assert calculate_file_hash(eval_file) != old_eval_hash + + +def test_max_eval_samples(): + """Test case where max_eval_samples < eval_split_ratio * total samples""" + output_path = os.path.join(OUTPUT_DIR, "max-eval-test") + max_eval_samples = 50 + split_eval_set_from_args( + TMPT_DIR, + DATA_PATH_SPLIT, + output_path, + EVAL_SPLIT_RATIO, + max_eval_samples, + ) + eval_lines = count_lines(os.path.join(output_path, "eval.jsonl")) + assert eval_lines == max_eval_samples + + +def test_eval_split_ratio(): + """Test case where max_eval_samples is not used""" + output_path = os.path.join(OUTPUT_DIR, "eval-split-test") + split_eval_set_from_args(TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO) + original_data_lines = count_lines(DEFAULT_FILE) + eval_lines = count_lines(os.path.join(output_path, "eval.jsonl")) + assert abs(eval_lines - EVAL_SPLIT_RATIO * original_data_lines) < 1 # allow for rounding errors + + +def test_seed_consistency(): + """Test if the same seed generates consistent splits""" + output_path_1 = os.path.join(OUTPUT_DIR, "seed-test-1") + output_path_2 = os.path.join(OUTPUT_DIR, "seed-test-2") + split_examples(DEFAULT_FILE, output_path_1, EVAL_SPLIT_RATIO, seed=12345) + split_examples(DEFAULT_FILE, output_path_2, EVAL_SPLIT_RATIO, seed=12345) + train_hash_1 = calculate_file_hash(os.path.join(output_path_1, "train.jsonl")) + train_hash_2 = calculate_file_hash(os.path.join(output_path_2, "train.jsonl")) + eval_hash_1 = calculate_file_hash(os.path.join(output_path_1, "eval.jsonl")) + eval_hash_2 = calculate_file_hash(os.path.join(output_path_2, "eval.jsonl")) + + assert train_hash_1 == train_hash_2 + assert eval_hash_1 == eval_hash_2 + + output_path_3 = os.path.join(OUTPUT_DIR, "seed-test-3") + split_examples(DEFAULT_FILE, output_path_3, EVAL_SPLIT_RATIO, seed=54321) + train_hash_3 = calculate_file_hash(os.path.join(output_path_3, "train.jsonl")) + eval_hash_3 = calculate_file_hash(os.path.join(output_path_3, "eval.jsonl")) + + assert train_hash_1 != train_hash_3 + assert eval_hash_1 != eval_hash_3 + + +def test_hf_data_split(): + """Test splitting a dataset from Hugging Face""" + output_path = os.path.join(OUTPUT_DIR, "hf-split-test") + split_eval_set_from_args( + "databricks/databricks-dolly-15k", "train", output_path, EVAL_SPLIT_RATIO + ) + assert os.path.isfile(os.path.join(output_path, "train.jsonl")) + assert os.path.isfile(os.path.join(output_path, "eval.jsonl")) + assert count_lines(os.path.join(output_path, "train.jsonl")) > 0 + assert count_lines(os.path.join(output_path, "eval.jsonl")) > 0 + + +def _mock_get_file(remote_path: str, data_path: str, overwrite: bool): + with open(data_path, "w") as f: + for i in range(1000): + 
f.write(json.dumps({"prompt": "hello world " + str(i), "response": "hi you!"}) + "\n") + + +def test_remote_store_data_split(): + """Test splitting a dataset from a remote store""" + output_path = os.path.join(OUTPUT_DIR, "remote-split-test") + with patch("composer.utils.get_file", side_effect=_mock_get_file) as mock_get_file: + split_eval_set_from_args( + "dbfs:/Volumes/test/test/test.jsonl", + "unique-split-name", + output_path, + EVAL_SPLIT_RATIO, + ) + mock_get_file.assert_called() + + assert os.path.isfile(os.path.join(output_path, "train.jsonl")) + assert os.path.isfile(os.path.join(output_path, "eval.jsonl")) + assert count_lines(os.path.join(output_path, "train.jsonl")) > 0 + assert count_lines(os.path.join(output_path, "eval.jsonl")) > 0 + + +def test_missing_delta_file_error(): + # expects file 'TMPT_DIR/missing-00000-of-00001.jsonl + with pytest.raises(FileNotFoundError): + split_eval_set_from_args(TMPT_DIR, "missing", OUTPUT_DIR, EVAL_SPLIT_RATIO) + + +def test_unknown_file_format_error(): + with pytest.raises(ValueError): + split_eval_set_from_args("s3:/path/to/file.jsonl", "train", OUTPUT_DIR, EVAL_SPLIT_RATIO)