forked from DRAGNLabs/301r_retnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_tokenizer.py
98 lines (75 loc) · 2.91 KB
/
train_tokenizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# NOTE: the query-planning flag is set between `import dask` and
# `import dask.dataframe` — presumably it must be configured before
# dask.dataframe is first imported to take effect (TODO confirm against
# the dask version pinned by this project).
import dask
dask.config.set({"dataframe.query-planning": True})
import dask.dataframe as dd
import sys
import yaml
from pathlib import Path
from tokenizers import decoders, pre_tokenizers, processors, Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
from utils import Struct
def train_tokenizer(config):
    """Train a byte-level BPE tokenizer on the training split of a dataset.

    Reads every ``*.parquet`` file under ``<raw_dataset_path>/train``,
    trains a BPE model with the configured vocabulary size, wraps it in a
    ``PreTrainedTokenizerFast``, and saves it to ``config.tokenizer_path``.

    Args:
        config: Struct expected to provide ``raw_dataset_path``,
            ``dataset_feature``, ``vocab_size``, ``seq_len``, and
            ``tokenizer_path``.
    """
    print(f"Data dir: {config.raw_dataset_path}")
    print("Loading dataset from disk")
    train_dataset_path = Path(config.raw_dataset_path) / "train" / "*.parquet"

    # Only load in train set, as that's all the tokenizer needs.
    # NOTE: .compute() materializes the whole training column in memory.
    dataset = dd.read_parquet(path=train_dataset_path,
        columns=[str(config.dataset_feature)]).compute()
    print("Loaded!")

    print("Creating tokenizer")
    # Create BytePair Encoding tokenizer and trainer
    tokenizer = Tokenizer(BPE(unk_token="<unk>"))
    trainer = BpeTrainer(
        vocab_size=config.vocab_size,
        show_progress=True,
        special_tokens=["<pad>", "<bos>", "<unk>"])

    # Like GPT-2, we skip the normalizer and go directly to pre-tokenization.
    # add_prefix_space=False: do not add a space at the beginning of a
    # sentence (which is the default otherwise).
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

    print("Training tokenizer")
    # Train tokenizer on only training data
    tokenizer.train_from_iterator(
        iter(dataset[config.dataset_feature]),
        trainer=trainer)

    # Prepend <bos> to every encoded sequence.
    # (The previous comment here claimed trim_offsets=False was passed;
    # TemplateProcessing takes no such argument, so it was removed.)
    tokenizer.post_processor = processors.TemplateProcessing(
        single="<bos> $A",
        special_tokens=[("<bos>", tokenizer.token_to_id("<bos>"))],
    )

    # Add decoder for converting tokens back to text
    tokenizer.decoder = decoders.ByteLevel()

    # Look up the pad id instead of hard-coding 0, so a reordering of
    # special_tokens above cannot silently corrupt padding.
    pad_id = tokenizer.token_to_id("<pad>")

    # Enable padding
    tokenizer.enable_padding(
        direction="right",
        pad_id=pad_id,
        pad_token="<pad>",
        length=config.seq_len + 1)

    # Enable truncation
    tokenizer.enable_truncation(
        max_length=config.seq_len + 1,
        direction="right")

    # Wrap tokenizer with transformers library
    tokenizer = PreTrainedTokenizerFast(
        model_max_length=config.seq_len,
        padding_side="right",
        truncation_side="right",
        bos_token="<bos>",
        unk_token="<unk>",
        pad_token="<pad>",
        tokenizer_object=tokenizer)

    print("Saving tokenizer to file...")
    # Save tokenizer to file; pass str() since older transformers
    # versions require a string path, not a PathLike.
    tokenizer_save_path = Path(config.tokenizer_path)
    tokenizer_save_path.mkdir(parents=True, exist_ok=True)
    tokenizer.save_pretrained(str(tokenizer_save_path))
    print("Done!")
if __name__ == "__main__":
    # Expect exactly one argument: the path to a YAML config file.
    # Fail with a usage message rather than a raw IndexError.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <config_path>")

    config_path = sys.argv[1]
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Wrap the raw dict so fields are attribute-accessible.
    config = Struct(**config)

    print("Training tokenizer...")
    train_tokenizer(config)