diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
index 9db204e71a..3acdecfec0 100644
--- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
@@ -1,14 +1,10 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-import json
 import logging
 import os
 import tempfile
-from typing import Callable, Optional
-
-import numpy as np
-from streaming import MDSWriter
+from typing import Optional
 
 from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
     _check_imports,
@@ -17,59 +13,12 @@
     get_columns_info,
     validate_and_get_cluster_info,
 )
+from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import \
+    convert_finetuning_dataset_from_args
 
 logger = logging.getLogger(__name__)
 
 
-def get_conversion_config(
-    columns: list[str],
-    provided_dtypes: Optional[dict],
-) -> tuple[dict, Callable]:
-    """If no dtypes is provided, attempts to infer config based on column names.
-
-    Args:
-        columns (List[str]): The list of column names.
-        provided_dtypes (Optional[Dict]): The provided dtypes.
-    """
-    if provided_dtypes is not None:
-        convert_x = lambda x: {
-            k: np.array(v, dtype=provided_dtypes.get(k)) for k, v in x.items()
-        }
-        return provided_dtypes, convert_x
-
-    if len(columns) != 1:
-        raise ValueError(
-            'Unable to infer dtypes from columns and no dtypes provided.',
-        )
-
-    if 'turns' in columns[0]:
-        logging.info('Identified IFT/CHAT data')
-        dtypes = {
-            'input_ids': 'ndarray',
-            'attention_mask': 'ndarray',
-            'labels': 'ndarray',
-        }
-        convert_x = lambda x: (
-            ValueError('More than one turn found') if len(x['turns']) > 1 else {
-                'input_ids': np.array(x['turns'][0]['input_ids']),
-                'attention_mask': np.array(x['turns'][0]['attention_mask']),
-                'labels': np.array(x['turns'][0]['labels']),
-            }
-        )
-    elif 'concat_tokens' in columns[0]:
-        logging.info('Identified CPT data')
-        dtypes = {
-            'tokens': 'ndarray',
-        }
-        convert_x = lambda x: {'tokens': np.array(x['concat_tokens'])}
-    else:
-        raise ValueError(
-            'Unable to infer dtypes from columns and no dtypes provided.',
-        )
-
-    return dtypes, convert_x
-
-
 def convert_delta_to_mds_from_args(
     delta_table_name: str,
     mds_output_folder: str,
@@ -115,12 +64,6 @@
     )
     logger.info(f'Columns: {columns}')
 
-    dtypes, convert_x = get_conversion_config(columns, dtypes)
-
-    compression = 'zstd:7'
-    hashes = ['sha1']
-    limit = '10mb'
-
    logging.info(f'Fetching data from Delta Table {delta_table_name}...')
 
     with tempfile.TemporaryDirectory() as json_out_folder:
@@ -142,19 +85,24 @@
         except Exception as e:
             logger.error(f'Error fetching data from Delta Table: {e}')
             raise e
 
-        with MDSWriter(
-            out=mds_output_folder,
-            columns=dtypes,
-            compression=compression,
-            hashes=hashes,
-            size_limit=limit,
-        ) as out:
-            try:
-                with open(json_full_filepath, 'r') as f:
-                    for line in f:
-                        out.write(convert_x(json.loads(line)))
-            except FileNotFoundError as e:
-                logger.error(f'JSON output file not found: {e}')
-                raise e
-        logging.info(f'Wrote to MDS at {mds_output_folder}')
+        convert_finetuning_dataset_from_args(
+            dataset='json',
+            data_subset=None,
+            splits=[''],
+            preprocessor=None,
+            data_files=[json_full_filepath],
+            skip_preprocessing=True,
+            out_root=mds_output_folder,
+            local=None,
+            compression='zstd:7',
+            num_workers=processes,
+            tokenizer=None,
+            tokenizer_kwargs=None,
+            max_seq_len=-1,
+            target_prompts='',
+            target_responses='',
+            encoder_decoder=False,
+        )
+
+    logging.info(f'Wrote to MDS at {mds_output_folder}')
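
For reference, a minimal sketch of how the rewired entry point might be invoked after this change. Only `delta_table_name` and `mds_output_folder` are confirmed by the hunks above; every other keyword argument is an assumption inferred from the helpers the module imports (`validate_and_get_cluster_info`, `fetch`) and from `num_workers=processes` in the new call, so check the full signature before relying on it.

```python
# Hypothetical usage sketch -- kwargs other than delta_table_name and
# mds_output_folder are assumptions, not confirmed by the diff above.
import os

from llmfoundry.command_utils.data_prep.convert_delta_to_mds import \
    convert_delta_to_mds_from_args

convert_delta_to_mds_from_args(
    delta_table_name='catalog.schema.my_table',  # placeholder Delta table
    mds_output_folder='/local/mds_out',          # placeholder output path
    http_path=None,                     # assumed kwarg; cluster info path
    cluster_id='0000-000000-abcdefgh',  # placeholder id (assumed kwarg)
    use_serverless=False,               # assumed kwarg, mirrors convert_delta_to_json
    batch_size=1 << 20,                 # assumed kwarg: rows fetched per batch
    processes=os.cpu_count() or 1,      # forwarded as num_workers in the new call
    dtypes=None,                        # assumed kwarg; no longer drives MDS columns here
)
```

The net effect of the diff is to delete the hand-rolled `MDSWriter` loop and the column-sniffing heuristics in `get_conversion_config`, delegating MDS writing to `convert_finetuning_dataset_from_args` instead; the tokenizer-related arguments in the new call (`tokenizer=None`, `max_seq_len=-1`) are pass-throughs that this path does not use.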