error handling and testing
mattyding authored and jimmyxu-db committed Oct 22, 2024
1 parent 87d7a4c commit b04651f
Showing 3 changed files with 183 additions and 24 deletions.
6 changes: 5 additions & 1 deletion llmfoundry/command_utils/__init__.py
@@ -20,7 +20,10 @@
     convert_text_to_mds,
     convert_text_to_mds_from_args,
 )
-from llmfoundry.command_utils.data_prep.split_eval_set import split_eval_set_from_args
+from llmfoundry.command_utils.data_prep.split_eval_set import (
+    split_eval_set_from_args,
+    split_examples,
+)
 from llmfoundry.command_utils.eval import (
     eval_from_yaml,
     evaluate,
@@ -52,4 +55,5 @@
     "convert_delta_to_json_from_args",
     "fetch_DT",
     "split_eval_set_from_args",
+    "split_examples",
 ]
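For context, a minimal usage sketch of the two helpers this module now exports. The file paths and argument values below are illustrative, not taken from the diff; the positional call style mirrors the tests later in this commit:

    from llmfoundry.command_utils import split_eval_set_from_args, split_examples

    # Resolve a data source (Delta-table export, remote object store file, or
    # Hugging Face dataset id) to a local jsonl, then split it:
    split_eval_set_from_args("databricks/databricks-dolly-15k", "train", "out-dir", 0.1)

    # Or split an already-local jsonl file directly:
    split_examples("out-dir/train.jsonl", "out-dir-2", 0.1, max_eval_samples=100, seed=42)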
38 changes: 15 additions & 23 deletions llmfoundry/command_utils/data_prep/split_eval_set.py
@@ -10,7 +10,7 @@
 import numpy as np
 from typing import Optional
 
-from composer.utils import get_file
+import composer.utils as utils
 from llmfoundry.data.finetuning.tasks import maybe_safe_download_hf_data
 
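An aside on the import change above. The commit message doesn't say so explicitly, but the likely motivation is testability: `from composer.utils import get_file` binds the function into this module's own namespace, so `unittest.mock.patch("composer.utils.get_file")` would not intercept calls made through the local name. Accessing it as `utils.get_file` defers the attribute lookup to call time, which is what lets the new remote-store test below patch it. A small standalone illustration (the `fetch` helper is hypothetical, not part of the commit):

    from unittest.mock import patch
    import composer.utils as utils

    def fetch(remote_path: str, local_path: str) -> None:
        # late attribute lookup: resolved against composer.utils at call time
        utils.get_file(remote_path, local_path, overwrite=True)

    with patch("composer.utils.get_file") as mock_get_file:
        fetch("dbfs:/Volumes/test/test/test.jsonl", "/tmp/test.jsonl")
        mock_get_file.assert_called_once()  # the patch is visible through utils.get_file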

@@ -24,11 +24,6 @@
 
 log = logging.getLogger(__name__)
 
-import sys
-
-log.setLevel(logging.DEBUG)
-log.addHandler(logging.StreamHandler(sys.stdout))
-
 
 def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> str:
     """
@@ -51,22 +46,16 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> str:
     os.makedirs(TEMP_DIR, exist_ok=True)
 
     if DELTA_JSONL_REGEX.match(data_path_folder):
+        log.info(f"Dataset is converted from Delta table. Using local file {data_path_folder}")
         data_path = os.path.join(data_path_folder, f"{data_path_split}-00000-of-00001.jsonl")
-        if not os.path.exists(data_path):
-            # TODO: error handling
-            raise FileNotFoundError(f"File {data_path} does not exist.")
 
-    if REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder):
+    elif REMOTE_OBJECT_STORE_FILE_REGEX.match(data_path_folder):
         log.info(
             f"Downloading dataset from remote object store: {data_path_folder}{data_path_split}.jsonl"
         )
         remote_path = f"{data_path_folder}/{data_path_split}.jsonl"
         data_path = os.path.join(TEMP_DIR, f"{data_path_split}.jsonl")
-        try:
-            get_file(remote_path, data_path, overwrite=True)
-        except FileNotFoundError as e:
-            # TODO: error handling
-            raise e
+        utils.get_file(remote_path, data_path, overwrite=True)
 
     elif HF_REGEX.match(data_path_folder):
         log.info(
@@ -85,20 +74,21 @@ def maybe_download_data_as_json(data_path_folder: str, data_path_split: str) -> str:
                 f.write(json.dumps(example) + "\n")
 
     else:
-        # TODO: error handling
         raise ValueError(
-            f"Unrecognized data_path_folder: {data_path_folder}. Must be a Delta table, remote object store file, or Hugging Face dataset."
+            f"Encountered unknown data path format when splitting dataset: {data_path_folder} with split {data_path_split}"
         )
 
     if not os.path.exists(data_path):
-        # TODO: error handling
-        raise FileNotFoundError(f"File {data_path} does not exist.")
+        raise FileNotFoundError(
+            f"Expected dataset file at {data_path} for splitting, but it does not exist."
+        )
 
     return data_path
 
 
 @contextlib.contextmanager
 def temp_seed(seed: int):
-    log.info(f"Setting random seed to {seed}")
     state = np.random.get_state()
     np.random.seed(seed)
     try:
@@ -107,22 +97,25 @@ def temp_seed(seed: int):
         np.random.set_state(state)


-def _split_examples(
+def split_examples(
     data_path: str,
     output_path: str,
     eval_split_ratio: float,
-    max_eval_samples: Optional[int],
+    max_eval_samples: Optional[int] = None,
     seed: Optional[int] = None,
 ) -> None:
     """
     Splits the dataset into training and evaluation sets.
     Args:
         data_path (str): Path to the training dataset (local jsonl file)
         output_path (str): Directory to save the split dataset
         eval_split_ratio (float): Ratio of the dataset to use for evaluation. The remainder will be used for training
         max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used
         seed (int): Random seed for splitting the dataset
     """
+    os.makedirs(output_path, exist_ok=True)
+
     # first pass: count total number of lines and determine sample size
     total_lines = 0
     with open(data_path, "r") as infile:
@@ -170,6 +163,5 @@ def split_eval_set_from_args(
         max_eval_samples (int): Maximum number of samples to include in the eval set. If None, all eval_split_ratio * train_dataset_size samples will be used
         seed (int): Random seed for splitting the dataset
     """
-    os.makedirs(output_path, exist_ok=True)
     data_path = maybe_download_data_as_json(data_path_folder, data_path_split)
-    _split_examples(data_path, output_path, eval_split_ratio, max_eval_samples, seed)
+    split_examples(data_path, output_path, eval_split_ratio, max_eval_samples, seed)
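The body of split_examples between the two hunks above is collapsed in this view. For readers following along, here is a minimal sketch of a two-pass implementation consistent with the visible pieces (the line-count first pass, temp_seed, and the max_eval_samples cap). The name split_examples_sketch is hypothetical, it assumes the module's imports (os, contextlib, numpy as np) plus temp_seed from this file, and the actual commit may differ in details:

    def split_examples_sketch(data_path, output_path, eval_split_ratio, max_eval_samples=None, seed=None):
        # first pass: count lines so we can size the eval sample
        with open(data_path, "r") as infile:
            total_lines = sum(1 for _ in infile)
        sample_size = int(eval_split_ratio * total_lines)
        if max_eval_samples is not None:
            sample_size = min(sample_size, max_eval_samples)

        # pick eval line indices reproducibly; temp_seed restores RNG state afterwards
        with temp_seed(seed) if seed is not None else contextlib.nullcontext():
            eval_indices = set(np.random.choice(total_lines, sample_size, replace=False))

        # second pass: route each line to train.jsonl or eval.jsonl
        os.makedirs(output_path, exist_ok=True)
        with open(data_path, "r") as infile, \
             open(os.path.join(output_path, "train.jsonl"), "w") as train_out, \
             open(os.path.join(output_path, "eval.jsonl"), "w") as eval_out:
            for i, line in enumerate(infile):
                (eval_out if i in eval_indices else train_out).write(line)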
163 changes: 163 additions & 0 deletions tests/a_scripts/data_prep/test_split_eval_set.py
@@ -0,0 +1,163 @@
import os
import json
import shutil
import pytest
import hashlib
from unittest.mock import patch

from llmfoundry.command_utils import split_eval_set_from_args, split_examples

# Default values
OUTPUT_DIR = "tmp-split"
TMPT_DIR = "tmp-t"
DATA_PATH_SPLIT = "train"
EVAL_SPLIT_RATIO = 0.1
DEFAULT_FILE = TMPT_DIR + "/train-00000-of-00001.jsonl"


def calculate_file_hash(filepath: str) -> str:
    with open(filepath, "rb") as f:
        file_hash = hashlib.sha256(f.read()).hexdigest()
    return file_hash


def count_lines(filepath: str) -> int:
    with open(filepath, "r") as f:
        return sum(1 for _ in f)


@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown_module():
    # Setup: create local testing file
    os.makedirs(TMPT_DIR, exist_ok=True)
    with open(DEFAULT_FILE, "w") as f:
        for i in range(1000):
            f.write(json.dumps({"prompt": "hello world " + str(i), "response": "hi you!"}) + "\n")
    yield

    # Teardown: clean up output and tmp directories
    shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
    shutil.rmtree(TMPT_DIR, ignore_errors=True)


def test_basic_split():
    """Test basic functionality on local file"""
    output_path = os.path.join(OUTPUT_DIR, "basic-test")
    split_eval_set_from_args(TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO)
    assert os.path.isfile(os.path.join(output_path, "train.jsonl"))
    assert os.path.isfile(os.path.join(output_path, "eval.jsonl"))


def test_basic_split_output_exists():
    """Test that split overwrites existing files in output directory"""
    output_path = os.path.join(OUTPUT_DIR, "basic-test")
    os.makedirs(output_path, exist_ok=True)
    train_file = os.path.join(output_path, "train.jsonl")
    eval_file = os.path.join(output_path, "eval.jsonl")
    with open(train_file, "w") as f:
        f.write("existing file train")
    with open(eval_file, "w") as f:
        f.write("existing file eval")
    old_train_hash = calculate_file_hash(train_file)
    old_eval_hash = calculate_file_hash(eval_file)
    split_eval_set_from_args(
        TMPT_DIR,
        DATA_PATH_SPLIT,
        output_path,
        EVAL_SPLIT_RATIO,
    )
    assert calculate_file_hash(train_file) != old_train_hash
    assert calculate_file_hash(eval_file) != old_eval_hash


def test_max_eval_samples():
    """Test case where max_eval_samples < eval_split_ratio * total samples"""
    output_path = os.path.join(OUTPUT_DIR, "max-eval-test")
    max_eval_samples = 50
    split_eval_set_from_args(
        TMPT_DIR,
        DATA_PATH_SPLIT,
        output_path,
        EVAL_SPLIT_RATIO,
        max_eval_samples,
    )
    eval_lines = count_lines(os.path.join(output_path, "eval.jsonl"))
    assert eval_lines == max_eval_samples


def test_eval_split_ratio():
    """Test case where max_eval_samples is not used"""
    output_path = os.path.join(OUTPUT_DIR, "eval-split-test")
    split_eval_set_from_args(TMPT_DIR, DATA_PATH_SPLIT, output_path, EVAL_SPLIT_RATIO)
    original_data_lines = count_lines(DEFAULT_FILE)
    eval_lines = count_lines(os.path.join(output_path, "eval.jsonl"))
    # allow for rounding errors
    assert abs(eval_lines - EVAL_SPLIT_RATIO * original_data_lines) < 1


def test_seed_consistency():
    """Test if the same seed generates consistent splits"""
    output_path_1 = os.path.join(OUTPUT_DIR, "seed-test-1")
    output_path_2 = os.path.join(OUTPUT_DIR, "seed-test-2")
    split_examples(DEFAULT_FILE, output_path_1, EVAL_SPLIT_RATIO, seed=12345)
    split_examples(DEFAULT_FILE, output_path_2, EVAL_SPLIT_RATIO, seed=12345)
    train_hash_1 = calculate_file_hash(os.path.join(output_path_1, "train.jsonl"))
    train_hash_2 = calculate_file_hash(os.path.join(output_path_2, "train.jsonl"))
    eval_hash_1 = calculate_file_hash(os.path.join(output_path_1, "eval.jsonl"))
    eval_hash_2 = calculate_file_hash(os.path.join(output_path_2, "eval.jsonl"))

    assert train_hash_1 == train_hash_2
    assert eval_hash_1 == eval_hash_2

    output_path_3 = os.path.join(OUTPUT_DIR, "seed-test-3")
    split_examples(DEFAULT_FILE, output_path_3, EVAL_SPLIT_RATIO, seed=54321)
    train_hash_3 = calculate_file_hash(os.path.join(output_path_3, "train.jsonl"))
    eval_hash_3 = calculate_file_hash(os.path.join(output_path_3, "eval.jsonl"))

    assert train_hash_1 != train_hash_3
    assert eval_hash_1 != eval_hash_3


def test_hf_data_split():
    """Test splitting a dataset from Hugging Face"""
    output_path = os.path.join(OUTPUT_DIR, "hf-split-test")
    split_eval_set_from_args(
        "databricks/databricks-dolly-15k", "train", output_path, EVAL_SPLIT_RATIO
    )
    assert os.path.isfile(os.path.join(output_path, "train.jsonl"))
    assert os.path.isfile(os.path.join(output_path, "eval.jsonl"))
    assert count_lines(os.path.join(output_path, "train.jsonl")) > 0
    assert count_lines(os.path.join(output_path, "eval.jsonl")) > 0


def _mock_get_file(remote_path: str, data_path: str, overwrite: bool):
    with open(data_path, "w") as f:
        for i in range(1000):
            f.write(json.dumps({"prompt": "hello world " + str(i), "response": "hi you!"}) + "\n")


def test_remote_store_data_split():
    """Test splitting a dataset from a remote store"""
    output_path = os.path.join(OUTPUT_DIR, "remote-split-test")
    with patch("composer.utils.get_file", side_effect=_mock_get_file) as mock_get_file:
        split_eval_set_from_args(
            "dbfs:/Volumes/test/test/test.jsonl",
            "unique-split-name",
            output_path,
            EVAL_SPLIT_RATIO,
        )
        mock_get_file.assert_called()

    assert os.path.isfile(os.path.join(output_path, "train.jsonl"))
    assert os.path.isfile(os.path.join(output_path, "eval.jsonl"))
    assert count_lines(os.path.join(output_path, "train.jsonl")) > 0
    assert count_lines(os.path.join(output_path, "eval.jsonl")) > 0


def test_missing_delta_file_error():
    # expects file TMPT_DIR/missing-00000-of-00001.jsonl, which does not exist
    with pytest.raises(FileNotFoundError):
        split_eval_set_from_args(TMPT_DIR, "missing", OUTPUT_DIR, EVAL_SPLIT_RATIO)


def test_unknown_file_format_error():
    with pytest.raises(ValueError):
        split_eval_set_from_args("s3:/path/to/file.jsonl", "train", OUTPUT_DIR, EVAL_SPLIT_RATIO)

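The new tests can be run on their own with a standard pytest invocation (note that test_hf_data_split downloads databricks/databricks-dolly-15k, so it requires network access):

    pytest tests/a_scripts/data_prep/test_split_eval_set.py -v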