forked from jingyaogong/minimind
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_tokenizer.py
162 lines (134 loc) · 5.26 KB
/
train_tokenizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import random
from tqdm import tqdm
from transformers import AutoTokenizer
import json
from datasets import load_dataset
from tokenizers import (
decoders,
models,
normalizers,
pre_tokenizers,
processors,
trainers,
Tokenizer,
)
import os
# Seed Python's `random` module with a fixed value. NOTE(review): presumably for
# reproducibility; nothing in the visible code draws from `random` directly --
# confirm whether this is still needed before removing it.
random.seed(42)
def train_tokenizer(
    data_path='./dataset/tokenizer_train.jsonl',
    tokenizer_dir="./model/minimind_tokenizer",
    vocab_size=6400,
):
    """Train a byte-level BPE tokenizer and save it with a hand-written config.

    The corpus is streamed from a JSONL file (one JSON object per line, each
    carrying a ``"text"`` field). After training, three artifacts are written
    into ``tokenizer_dir``:

    * ``tokenizer.json`` -- the serialized tokenizer (``Tokenizer.save``),
    * the files produced by ``tokenizer.model.save``,
    * ``tokenizer_config.json`` -- the config dict built below, including the
      chat template.

    Args:
        data_path: Path of the JSONL training corpus.
        tokenizer_dir: Output directory; created if it does not exist.
        vocab_size: Target vocabulary size handed to ``BpeTrainer``.

    Raises:
        AssertionError: If ``<unk>``, ``<s>`` and ``</s>`` did not receive the
            ids 0, 1 and 2 that the saved config relies on.
    """

    def read_texts_from_jsonl(file_path):
        """Yield the ``"text"`` field of every JSON object in ``file_path``."""
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Skip blank / whitespace-only lines instead of letting
                # json.loads fail on them.
                if not line.strip():
                    continue
                yield json.loads(line)['text']

    # Single source of truth: the value written to tokenizer_config.json must
    # describe the same pre-tokenization that the tokenizer was trained with.
    add_prefix_space = False

    # Initialise a BPE tokenizer with byte-level pre-tokenization.
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)

    # Special tokens; their order defines the ids checked below (0, 1, 2).
    special_tokens = ["<unk>", "<s>", "</s>"]

    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=special_tokens,  # make sure these tokens are included
        show_progress=True,
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )

    # The generator is consumed lazily, so the corpus is never fully in memory.
    texts = read_texts_from_jsonl(data_path)
    tokenizer.train_from_iterator(texts, trainer=trainer)

    # Byte-level decoder to match the byte-level pre-tokenizer.
    tokenizer.decoder = decoders.ByteLevel()

    # The config below keys the special tokens by id, so verify the ids with an
    # explicit raise (a bare `assert` would be stripped under `python -O`).
    for expected_id, token in enumerate(special_tokens):
        actual_id = tokenizer.token_to_id(token)
        if actual_id != expected_id:
            raise AssertionError(
                f"special token {token!r} has id {actual_id}, expected {expected_id}"
            )

    # Save the tokenizer; both save calls target the same directory.
    os.makedirs(tokenizer_dir, exist_ok=True)
    tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
    tokenizer.model.save(tokenizer_dir)

    # One entry per special token, keyed by its (verified) id as a string.
    added_tokens_decoder = {
        str(token_id): {
            "content": token,
            "lstrip": False,
            "normalized": False,
            "rstrip": False,
            "single_word": False,
            "special": True,
        }
        for token_id, token in enumerate(special_tokens)
    }

    # Hand-written tokenizer config.
    config = {
        "add_bos_token": False,
        "add_eos_token": False,
        "add_prefix_space": add_prefix_space,
        "added_tokens_decoder": added_tokens_decoder,
        "additional_special_tokens": [],
        "bos_token": "<s>",
        "clean_up_tokenization_spaces": False,
        "eos_token": "</s>",
        "legacy": True,
        "model_max_length": 1000000000000000019884624838656,
        "pad_token": None,
        "sp_model_kwargs": {},
        "spaces_between_special_tokens": False,
        "tokenizer_class": "PreTrainedTokenizerFast",
        "unk_token": "<unk>",
        "use_default_system_prompt": False,
        "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
    }

    # Save the config next to the tokenizer files.
    with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as config_file:
        json.dump(config, config_file, ensure_ascii=False, indent=4)

    print("Tokenizer training completed and saved.")
def eval_tokenizer():
    """Load the saved tokenizer and print a handful of smoke-test results.

    Printed in order: the chat-template rendering of a sample conversation,
    ``tokenizer.vocab_size``, ``len(tokenizer)``, a sample sentence, its
    encoding, the number of input ids, and the text decoded from those ids.
    """
    from transformers import AutoTokenizer

    # Load the tokenizer previously written to disk.
    tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")

    # Render a short multi-turn conversation through the chat template,
    # as text rather than token ids.
    conversation = [
        {"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
        {"role": "user", "content": '是椭圆形的'},
        {"role": "assistant", "content": '456'},
        {"role": "user", "content": '456'},
        {"role": "assistant", "content": '789'},
    ]
    rendered = tokenizer.apply_chat_template(conversation, tokenize=False)
    print(rendered)

    # Vocabulary size as reported by the tokenizer, then the full length of the
    # tokenizer object (the original comments describe these as excluding and
    # including special tokens respectively).
    print('tokenizer词表大小:', tokenizer.vocab_size)
    print('qwen实际词表长度:', len(tokenizer))

    # Encode a mixed Chinese/ASCII/emoji sentence and decode it back.
    sample_text = 'wenjie,椭圆和⚪的关系是什么呢?因为明天下午要带家人去下医院,所以申请上午在家办公,因为明天下午要带家人去下医院,所以申请上午在家办公,因为明天下午要带家人去下医院,所以申请上午在家办公,下午请半天假~@LWJWe '
    print(sample_text)

    encoded = tokenizer(sample_text)
    print(encoded)

    token_ids = encoded['input_ids']
    print('长度:', len(token_ids))

    # No trailing newline after the decoded text, matching the original output.
    print(tokenizer.decode(token_ids), end='')
def main():
    """Script entry point: run the tokenizer smoke test only.

    Training is deliberately not triggered here.
    """
    # Call train_tokenizer() before this line when the tokenizer files need to
    # be (re)generated.
    eval_tokenizer()
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()