-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_dataset.alpaca.py
67 lines (53 loc) · 2.78 KB
/
custom_dataset.alpaca.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Derivative of https://github.com/meta-llama/llama-recipes/blob/5e857601900e48e1d05f956f05e94ac64e78976a/recipes/finetuning/datasets/custom_dataset.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# https://ai.meta.com/llama/license/
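# Dataset: Alpaca-style instruction data (records with instruction/input/output fields)
# loaded from a local JSON file -- not samsum (https://huggingface.co/datasets/samsum).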
import os
import copy
import datasets
import itertools
B_INST, E_INST = "[INST]", "[/INST]"

# Tokenizers with at least this many base-vocabulary entries take the
# chat-template (Llama 3-style) path in tokenize_dialog.
_LLAMA3_MIN_VOCAB = 128000
# Token id treated as the end-of-turn marker (Llama 3 <|eot_id|>).
_LLAMA3_EOT_ID = 128009


def tokenize_dialog(dialog, tokenizer):
    """Tokenize a user/assistant dialog into causal-LM training features.

    Two paths are selected by ``tokenizer.vocab_size``:

    * ``>= 128000``: the dialog is rendered with
      ``tokenizer.apply_chat_template`` and split into turns at each
      end-of-turn token (id 128009). Even-numbered turns (0, 2, ...) are
      masked out of ``labels``. This assumes the dialog strictly alternates
      user/assistant, starts with the user and has no system message; verify
      that for callers other than ``get_custom_dataset``.
    * otherwise: every user message is encoded as
      ``"{bos}[INST] {content} [/INST]"`` and every assistant message as
      ``"{content} {eos}"``. The prompt tokens are masked out of ``labels``.

    Args:
        dialog: List of ``{"role": ..., "content": ...}`` messages, user first,
            alternating user/assistant.
        tokenizer: Tokenizer exposing ``vocab_size`` and either
            ``apply_chat_template`` or ``encode``/``bos_token``/``eos_token``.

    Returns:
        Dict with ``input_ids``, ``labels`` (``-100`` where the loss should
        ignore the position) and ``attention_mask`` (all ones). All three
        lists have the same length.
    """
    if tokenizer.vocab_size >= _LLAMA3_MIN_VOCAB:
        tokens = list(tokenizer.apply_chat_template(dialog))
        eot_indices = [i for i, tok in enumerate(tokens) if tok == _LLAMA3_EOT_ID]
        if eot_indices:
            # Some chat templates append a dangling assistant header
            # (<|start_header_id|>assistant<|end_header_id|>\n\n) after the last
            # turn; others do not. Keeping everything up to the final
            # end-of-turn token strips that header when present. Unlike a fixed
            # [:-4] slice, it never truncates a real answer.
            tokens = tokens[: eot_indices[-1] + 1]
        labels = list(tokens)
        prompt_start = 0
        for turn, eot_idx in enumerate(eot_indices):
            if turn % 2 == 0:
                # Prompt turn: exclude it (through its eot token) from the loss.
                labels[prompt_start : eot_idx + 1] = [-100] * (eot_idx + 1 - prompt_start)
            else:
                # The next prompt starts right after this answer's eot token,
                # so the answer's own eot stays trainable.
                prompt_start = eot_idx + 1
        dialog_tokens = [tokens]
        labels_tokens = [labels]
    else:
        prompt_tokens = [
            tokenizer.encode(
                f"{tokenizer.bos_token}{B_INST} {prompt['content'].strip()} {E_INST}",
                add_special_tokens=False,
            )
            for prompt in dialog[::2]
        ]
        answer_tokens = [
            tokenizer.encode(
                f"{answer['content'].strip()} {tokenizer.eos_token}",
                add_special_tokens=False,
            )
            for answer in dialog[1::2]
        ]
        # Interleave prompt/answer segments: prompts land on even positions.
        dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
        # Prompt segments become -100 so the loss function ignores them.
        labels_tokens = [
            [-100] * len(segment) if i % 2 == 0 else segment
            for i, segment in enumerate(dialog_tokens)
        ]

    input_ids = list(itertools.chain.from_iterable(dialog_tokens))
    flat_labels = list(itertools.chain.from_iterable(labels_tokens))
    return {
        "input_ids": input_ids,
        "labels": flat_labels,
        "attention_mask": [1] * len(input_ids),
    }
def get_custom_dataset(dataset_config, tokenizer, split):
    """Load Alpaca-format JSON records and turn them into tokenized dialog features.

    Each record's ``input``, ``instruction`` and ``output`` fields become a
    single-turn user/assistant dialog. That dialog is then tokenized with
    ``tokenize_dialog``.

    Args:
        dataset_config: Dataset config object. If it has a non-empty
            ``data_path`` attribute, that JSON file is loaded. Otherwise the
            default ``datasets/alpaca/data/train.json`` is used, resolved
            relative to the current working directory.
        tokenizer: Tokenizer forwarded to ``tokenize_dialog``.
        split: Requested split name. NOTE(review): currently ignored. The
            whole file is returned for every split, so an evaluation split
            gets the training data. Confirm whether a held-out portion is
            intended.

    Returns:
        A ``datasets.Dataset`` whose columns are the keys produced by
        ``tokenize_dialog`` (``input_ids``, ``labels``, ``attention_mask``).
    """
    data_file = getattr(dataset_config, "data_path", None) or "datasets/alpaca/data/train.json"
    # The JSON builder exposes a single data file as its "train" split.
    dataset = datasets.load_dataset("json", data_files=data_file, split="train")

    def to_dialog(example):
        # Literal "\n\n" instead of os.linesep keeps the prompt text identical on
        # every platform (os.linesep is "\r\n" on Windows).
        user_content = f"Input: {example['input']}\n\nInstruction: {example['instruction']}"
        return {
            "dialog": [
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": example["output"]},
            ]
        }

    dataset = dataset.map(to_dialog, remove_columns=list(dataset.features))
    return dataset.map(
        lambda x: tokenize_dialog(x["dialog"], tokenizer),
        remove_columns=list(dataset.features),
    )