Skip to content

Commit

Permalink
assume single turn input
Browse files Browse the repository at this point in the history
  • Loading branch information
mattyding committed Dec 6, 2024
1 parent 19bf0a4 commit b5bf28c
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,19 @@ def get_conversion_config(
)

if 'turns' in columns[0]:
logging.info('Identified IFT data')
logging.info('Identified IFT/CHAT data')
dtypes = {
'input_ids': 'ndarray',
'attention_mask': 'ndarray',
'labels': 'ndarray',
}
convert_x = lambda x: {
# join the turns into a single array
'input_ids':
np.concatenate([
np.array(turn['input_ids']) for turn in x['turns']
]),
'attention_mask':
np.concatenate([
np.array(turn['attention_mask']) for turn in x['turns']
]),
'labels':
np.
concatenate([np.array(turn['labels']) for turn in x['turns']]),
}
convert_x = lambda x: (
ValueError('More than one turn found') if len(x['turns']) > 1 else {
'input_ids': np.array(x['turns'][0]['input_ids']),
'attention_mask': np.array(x['turns'][0]['attention_mask']),
'labels': np.array(x['turns'][0]['labels']),
}
)
elif 'concat_tokens' in columns[0]:
logging.info('Identified CPT data')
dtypes = {
Expand Down

0 comments on commit b5bf28c

Please sign in to comment.