"""Module containing the CustomMistralV2V3PromptTokenizingStrategy class"""
# Import necessary modules and functions
import copy
import logging
from collections import defaultdict
from typing import Generator, List, Tuple
# Import from axolotl package
from axolotl.prompt_tokenizers import (
PromptTokenizingStrategy,
parse_tokenized_to_result,
tokenize_prompt_default,
)
# Set up logging
LOG = logging.getLogger("axolotl")
# Define a constant token ID to ignore
IGNORE_TOKEN_ID = -100
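# (-100 matches the default ignore_index of torch.nn.CrossEntropyLoss, so
# label positions set to it are skipped when the training loss is computed.)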

class CustomMistralV2V3PromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for CustomMistralV2V3.
    """

    def __init__(self, prompter, tokenizer, *args, **kwargs):
        # Call the superclass' constructor
        super().__init__(prompter, tokenizer, *args, **kwargs)
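    # The strategy expects ShareGPT-style samples. A minimal illustrative
    # example (the field names are the ones read below; the texts are made up):
    #
    #   {
    #       "conversations": [
    #           {"from": "human", "value": "Hi there."},
    #           {"from": "gpt", "value": "Hello! How can I help?"},
    #       ]
    #   }
    #
    # "human-chat" and "gpt-chat" turns additionally carry a "name" field,
    # which gets prepended to the message as "name: value".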
    def tokenize_prompt(self, prompt):
        # Tokenize the prompt based on its conversation turns
        result, current_len = tokenize_prompt_default()

        # We don't want to remove the BOS token for the first turn
        strip_bos = False

        # The field is sometimes named 'conversations' and other times 'conversation'
        if "conversations" in prompt:
            conversation_name = "conversations"
        elif "conversation" in prompt:
            conversation_name = "conversation"
        else:
            LOG.warning("sample does not contain 'conversations' or 'conversation'")
            exit()
        # Iterate over each conversation turn in the prompt
        num_turns = len(prompt[conversation_name])
        for i, turn in enumerate(prompt[conversation_name]):
            # Strip the BOS token if it's not the first turn
            # (add_new_line is prepared here but not used by the templates below)
            if i == 0:
                strip_bos = False
                add_new_line = ""
            else:
                strip_bos = True
                add_new_line = "\n"

            # Check if this is the last turn, so we know to add the EOS token
            if i == num_turns - 1:
                end_of_text = True
            else:
                end_of_text = False

            # Get correct roles and messages
            sharegpt_from, sharegpt_value = turn["from"].strip(), turn["value"].strip()
            if sharegpt_from == "system":
                role_name = "system"
            elif sharegpt_from == "human":
                role_name = "user"
            elif sharegpt_from == "human-chat":
                role_name = "user"
                sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}"
            elif sharegpt_from == "gpt":
                role_name = "assistant"
            elif sharegpt_from == "gpt-chat":
                role_name = "assistant"
                sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}"
            else:
                LOG.warning(f"'from' contains an unhandled string: {sharegpt_from}")
                exit()
            # Get tokens which will be masked out if using train_on_inputs: false
            #prefix = self._tokenize(
            #    "",
            #    add_eos_token=False,
            #    strip_bos_token=strip_bos,
            #)

            # Get entire tokenized turn
            if role_name == "user":
                res = self._tokenize(
                    f"[INST] {sharegpt_value.strip()}[/INST]",
                    add_eos_token=end_of_text,
                    strip_bos_token=strip_bos,
                )
            elif role_name == "assistant":
                res = self._tokenize(
                    f" {sharegpt_value.strip()}</s>",
                    add_eos_token=end_of_text,
                    strip_bos_token=strip_bos,
                )
            elif role_name == "system":
                LOG.warning("sample contains unsupported system role, skipping")
                res = self._tokenize(
                    "",
                    add_eos_token=end_of_text,
                    strip_bos_token=strip_bos,
                )
            # Handle masked user turn
            if (
                self.train_on_inputs is False
                and (
                    sharegpt_from == "system"
                    or sharegpt_from == "human"
                    or sharegpt_from == "human-chat"
                )
            ):
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            # Handle partially masked model turn
            #elif (
            #    self.train_on_inputs is False
            #    and (
            #        sharegpt_from == "gpt"
            #        or sharegpt_from == "gpt-chat"
            #    )
            #):
            #    labels = (
            #        [IGNORE_TOKEN_ID] * len(prefix["input_ids"])  # Mask the prefix
            #        + [*copy.deepcopy(res["input_ids"])][len(prefix["input_ids"]):]
            #    )
            # Handle unmasked turn
            else:
                labels = res["input_ids"]

            # Parse tokenized result and update current length
            result, current_len = parse_tokenized_to_result(
                result,
                current_len,
                res,
                labels,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        return result
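# For reference, a two-turn "human"/"gpt" sample tokenized by the loop above
# corresponds roughly to the following text (assuming a Mistral-style
# tokenizer that prepends a BOS token; the spacing comes from the templates
# in tokenize_prompt):
#
#   <s>[INST] user message[/INST] assistant reply</s>
#
# With train_on_inputs: false, the labels for the [INST] ... [/INST] span are
# replaced by IGNORE_TOKEN_ID, so only the assistant reply contributes to the
# loss.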
# TODO: Remove this as it doesn't get used
class CustomMistralV2V3Prompter:
    """
    Prompter for CustomMistralV2V3.
    """

    def __init__(self, *args, **kwargs):
        # Constructor does nothing
        pass


# Function to load the CustomMistralV2V3PromptTokenizingStrategy
def load(tokenizer, cfg):
    return CustomMistralV2V3PromptTokenizingStrategy(
        CustomMistralV2V3Prompter(),  # TODO: Remove this as it doesn't get used
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
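# A minimal usage sketch (not part of the module): the model name, sample
# texts, and cfg fields below are illustrative, and the wiring assumes
# axolotl's convention of calling load(tokenizer, cfg) for a user-defined
# prompt strategy.
#
#   from types import SimpleNamespace
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
#   cfg = SimpleNamespace(train_on_inputs=False, sequence_len=8192)
#   strategy = load(tokenizer, cfg)
#   tokenized = strategy.tokenize_prompt(
#       {
#           "conversations": [
#               {"from": "human", "value": "Hi there."},
#               {"from": "gpt", "value": "Hello! How can I help?"},
#           ]
#       }
#   )
#   # tokenized should be a dict of input_ids / attention_mask / labels lists.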