Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
Signed-off-by: Jimmy Xu <[email protected]>
  • Loading branch information
jimmyxu-db committed Oct 29, 2024
1 parent c97fbb5 commit dd0de09
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 15 deletions.
6 changes: 3 additions & 3 deletions llmfoundry/command_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
convert_text_to_mds,
convert_text_to_mds_from_args,
)
from llmfoundry.command_utils.data_prep.split_eval_set import (
split_eval_set_from_args,
from llmfoundry.command_utils.data_prep.split_eval_data_from_train_data import (
split_eval_data_from_train_data_from_args,
split_examples,
)
from llmfoundry.command_utils.eval import (
Expand Down Expand Up @@ -58,6 +58,6 @@
'convert_text_to_mds_from_args',
'convert_delta_to_json_from_args',
'fetch_DT',
'split_eval_set_from_args',
'split_eval_data_from_train_data_from_args',
'split_examples',
]
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,15 @@ def split_examples(
)


def split_eval_set_from_args(
def split_eval_data_from_train_data_from_args(
data_path_folder: str,
data_path_split: str,
output_path: str,
eval_split_ratio: float,
max_eval_samples: Optional[int] = None,
seed: Optional[int] = None,
) -> None:
"""A wrapper for split_eval_set that parses arguments.
"""A wrapper for split_examples that parses arguments.
Args:
data_path_folder (str): Path to the training dataset folder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from argparse import ArgumentParser

from llmfoundry.command_utils import split_eval_set_from_args
from llmfoundry.command_utils import split_eval_data_from_train_data_from_args

if __name__ == '__main__':
parser = ArgumentParser(
Expand Down Expand Up @@ -51,7 +51,7 @@
help='Random seed for splitting the dataset',
)
args = parser.parse_args()
split_eval_set_from_args(
split_eval_data_from_train_data_from_args(
data_path_folder=args.data_path_folder,
data_path_split=args.data_path_split,
output_path=args.output_path,
Expand Down
19 changes: 11 additions & 8 deletions tests/a_scripts/data_prep/test_split_eval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@

import pytest

from llmfoundry.command_utils import split_eval_set_from_args, split_examples
from llmfoundry.command_utils.data_prep.split_eval_set import (
from llmfoundry.command_utils import (
split_eval_data_from_train_data_from_args,
split_examples,
)
from llmfoundry.command_utils.data_prep.split_eval_data_from_train_data import (
REMOTE_OBJECT_STORE_FILE_REGEX,
is_remote_object_store_file,
)
Expand Down Expand Up @@ -96,7 +99,7 @@ def setup_and_teardown_module():
def test_basic_split():
"""Test basic functionality on local file."""
output_path = os.path.join(OUTPUT_DIR, 'basic-test')
split_eval_set_from_args(
split_eval_data_from_train_data_from_args(
TMPT_DIR,
DATA_PATH_SPLIT,
output_path,
Expand All @@ -118,7 +121,7 @@ def test_basic_split_output_exists():
f.write('existing file eval')
old_train_hash = calculate_file_hash(train_file)
old_eval_hash = calculate_file_hash(eval_file)
split_eval_set_from_args(
split_eval_data_from_train_data_from_args(
TMPT_DIR,
DATA_PATH_SPLIT,
output_path,
Expand All @@ -132,7 +135,7 @@ def test_max_eval_samples():
"""Test case where max_eval_samples < eval_split_ratio * total samples"""
output_path = os.path.join(OUTPUT_DIR, 'max-eval-test')
max_eval_samples = 50
split_eval_set_from_args(
split_eval_data_from_train_data_from_args(
TMPT_DIR,
DATA_PATH_SPLIT,
output_path,
Expand All @@ -146,7 +149,7 @@ def test_max_eval_samples():
def test_eval_split_ratio():
"""Test case where max_eval_samples is not used."""
output_path = os.path.join(OUTPUT_DIR, 'eval-split-test')
split_eval_set_from_args(
split_eval_data_from_train_data_from_args(
TMPT_DIR,
DATA_PATH_SPLIT,
output_path,
Expand Down Expand Up @@ -206,7 +209,7 @@ def test_remote_store_data_split():
'composer.utils.get_file',
side_effect=_mock_get_file,
) as mock_get_file:
split_eval_set_from_args(
split_eval_data_from_train_data_from_args(
'dbfs:/Volumes/test/test/test.jsonl',
'unique-split-name',
output_path,
Expand All @@ -223,7 +226,7 @@ def test_remote_store_data_split():
def test_missing_delta_file_error():
# expects file 'TMPT_DIR/missing-00000-of-00001.jsonl'
with pytest.raises(FileNotFoundError):
split_eval_set_from_args(
split_eval_data_from_train_data_from_args(
TMPT_DIR,
'missing',
OUTPUT_DIR,
Expand Down

0 comments on commit dd0de09

Please sign in to comment.