
Commit

reuse convert_ft_dataset fn
mattyding committed Dec 7, 2024
1 parent b5bf28c commit 9372f48
Showing 1 changed file with 23 additions and 75 deletions.
98 changes: 23 additions & 75 deletions llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
@@ -1,14 +1,10 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import json
import logging
import os
import tempfile
from typing import Callable, Optional

import numpy as np
from streaming import MDSWriter
from typing import Optional

from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
    _check_imports,
@@ -17,59 +13,12 @@
    get_columns_info,
    validate_and_get_cluster_info,
)
from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import \
    convert_finetuning_dataset_from_args

logger = logging.getLogger(__name__)


def get_conversion_config(
    columns: list[str],
    provided_dtypes: Optional[dict],
) -> tuple[dict, Callable]:
    """If no dtypes is provided, attempts to infer config based on column names.
    Args:
        columns (List[str]): The list of column names.
        provided_dtypes (Optional[Dict]): The provided dtypes.
    """
    if provided_dtypes is not None:
        convert_x = lambda x: {
            k: np.array(v, dtype=provided_dtypes.get(k)) for k, v in x.items()
        }
        return provided_dtypes, convert_x

    if len(columns) != 1:
        raise ValueError(
            'Unable to infer dtypes from columns and no dtypes provided.',
        )

    if 'turns' in columns[0]:
        logging.info('Identified IFT/CHAT data')
        dtypes = {
            'input_ids': 'ndarray',
            'attention_mask': 'ndarray',
            'labels': 'ndarray',
        }
        convert_x = lambda x: (
            ValueError('More than one turn found') if len(x['turns']) > 1 else {
                'input_ids': np.array(x['turns'][0]['input_ids']),
                'attention_mask': np.array(x['turns'][0]['attention_mask']),
                'labels': np.array(x['turns'][0]['labels']),
            }
        )
    elif 'concat_tokens' in columns[0]:
        logging.info('Identified CPT data')
        dtypes = {
            'tokens': 'ndarray',
        }
        convert_x = lambda x: {'tokens': np.array(x['concat_tokens'])}
    else:
        raise ValueError(
            'Unable to infer dtypes from columns and no dtypes provided.',
        )

    return dtypes, convert_x


def convert_delta_to_mds_from_args(
    delta_table_name: str,
    mds_output_folder: str,
@@ -115,12 +64,6 @@ def convert_delta_to_mds_from_args(
    )
    logger.info(f'Columns: {columns}')

    dtypes, convert_x = get_conversion_config(columns, dtypes)

    compression = 'zstd:7'
    hashes = ['sha1']
    limit = '10mb'

    logging.info(f'Fetching data from Delta Table {delta_table_name}...')

    with tempfile.TemporaryDirectory() as json_out_folder:
@@ -142,19 +85,24 @@
        except Exception as e:
            logger.error(f'Error fetching data from Delta Table: {e}')
            raise e
        with MDSWriter(
            out=mds_output_folder,
            columns=dtypes,
            compression=compression,
            hashes=hashes,
            size_limit=limit,
        ) as out:
            try:
                with open(json_full_filepath, 'r') as f:
                    for line in f:
                        out.write(convert_x(json.loads(line)))
            except FileNotFoundError as e:
                logger.error(f'JSON output file not found: {e}')
                raise e

    logging.info(f'Wrote to MDS at {mds_output_folder}')
        convert_finetuning_dataset_from_args(
            dataset='json',
            data_subset=None,
            splits=[''],
            preprocessor=None,
            data_files=[json_full_filepath],
            skip_preprocessing=True,
            out_root=mds_output_folder,
            local=None,
            compression='zstd:7',
            num_workers=processes,
            tokenizer=None,
            tokenizer_kwargs=None,
            max_seq_len=-1,
            target_prompts='',
            target_responses='',
            encoder_decoder=False,
        )

    logging.info(f'Wrote to MDS at {mds_output_folder}')
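
Note: the commit replaces the hand-rolled MDSWriter loop and dtype-inference helper with a call to the existing convert_finetuning_dataset_from_args helper. A minimal sketch of the same call on a standalone JSONL export, using the arguments the new code passes (the file paths and worker count below are hypothetical placeholders, not values from the repository):

from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import \
    convert_finetuning_dataset_from_args

# Re-shard an already-tokenized JSONL export into MDS, mirroring the call
# added in this commit. Paths and num_workers are hypothetical examples.
convert_finetuning_dataset_from_args(
    dataset='json',
    data_subset=None,
    splits=[''],
    preprocessor=None,
    data_files=['/tmp/delta_export/train.jsonl'],  # hypothetical JSONL path
    skip_preprocessing=True,
    out_root='/tmp/mds_out',  # hypothetical MDS output folder
    local=None,
    compression='zstd:7',
    num_workers=8,
    tokenizer=None,
    tokenizer_kwargs=None,
    max_seq_len=-1,
    target_prompts='',
    target_responses='',
    encoder_decoder=False,
)

With skip_preprocessing=True and tokenizer=None, the helper presumably just writes the already-tokenized rows into zstd:7-compressed MDS shards, which is what the removed MDSWriter block did by hand.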
