tokenizer.py
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class TokenizerConfig:
    vocab_str: str = "ابتةثجحخدذرزسشصضطظعغفقكلمنهويءآأؤإئ0123456789.؟ "
    pad_token_id: int = 0
    unk_token_id: int = 1
    mask_token_id: int = 2
    unk_token: str = '$'
    num_special_token_ids: int = 3


class Tokenizer:
    def __init__(self, config: TokenizerConfig):
        self.vocab_str = config.vocab_str
        self.pad_token_id = config.pad_token_id
        self.unk_token_id = config.unk_token_id
        self.mask_token_id = config.mask_token_id
        self.unk_token = config.unk_token
        # Character ids start after the special token ids (pad, unk, mask).
        self.char2id = {
            char: idx + config.num_special_token_ids
            for idx, char in enumerate(self.vocab_str)
        }
        self.id2char = {
            token_id: char for char, token_id in self.char2id.items()
        }

    def token2id(self, char: str) -> int:
        return self.char2id.get(char, self.unk_token_id)

    def encode(
        self,
        text: str,
        max_length: Optional[int] = None,
        padding: bool = False,
        return_type: Optional[str] = None,
    ) -> List[int] | torch.LongTensor:
        """
        Args:
            max_length (int, optional): Truncate the encoded text to this length.
            padding (bool): Pad with `pad_token_id` up to `max_length` if True and `max_length` is not None.
            return_type (str, optional): Return a `torch.LongTensor` if set to "pt"; otherwise return a list of ints.
        """
        tokenized_text = [self.token2id(char) for char in text]
        if max_length is not None:
            tokenized_text = tokenized_text[:max_length]
            if padding:
                tokenized_text = self.pad(tokenized_text, max_length)
        if return_type == "pt":
            tokenized_text = torch.LongTensor(tokenized_text)
        return tokenized_text

    def decode(self, token_ids: List[int] | torch.LongTensor, skip_special_tokens: bool = True) -> str:
        """
        Args:
            token_ids (List[int] | torch.LongTensor): Token ids of shape (seq_len).
            skip_special_tokens (bool): If True, drop ids that do not map to a vocabulary character;
                otherwise replace them with `unk_token`.
        """
        if isinstance(token_ids, torch.Tensor):
            if token_ids.dim() != 1:
                raise ValueError("`decode` only supports tensors of a single dimension.")
            token_ids = token_ids.tolist()
        # Remove pad tokens
        token_ids = [token_id for token_id in token_ids if token_id != self.pad_token_id]
        if skip_special_tokens:
            return ''.join([self.id2char[token_id] for token_id in token_ids if token_id in self.id2char])
        return ''.join([self.id2char.get(token_id, self.unk_token) for token_id in token_ids])

    def batch_encode(
        self, texts: List[str], max_length: Optional[int] = None, return_type: Optional[str] = None, padding: str = 'longest'
    ) -> List[List[int]] | torch.LongTensor:
        """
        Args:
            padding (str): Padding strategy. 'longest' pads to the longest sequence in the batch
                (capped at `max_length` if given); 'max_length' pads every sequence to `max_length`.
        """
        tokenized_texts = [
            self.encode(text, max_length=None, padding=False) for text in texts
        ]
        batch_max_length = max(len(token_ids) for token_ids in tokenized_texts)
        if max_length is None:
            if padding == 'longest':
                max_length = batch_max_length
            else:
                raise ValueError('Padding is not set to "longest" and `max_length` is not specified.')
        elif padding == 'longest':
            max_length = min(max_length, batch_max_length)
        tokenized_texts = [token_ids[:max_length] for token_ids in tokenized_texts]
        tokenized_texts = [
            self.pad(token_ids, max_length) for token_ids in tokenized_texts
        ]
        if return_type == "pt":
            tokenized_texts = torch.LongTensor(tokenized_texts)
        return tokenized_texts

    def pad(self, token_ids: List[int], max_length: int) -> List[int]:
        return token_ids + [self.pad_token_id] * (max_length - len(token_ids))

if __name__ == '__main__':
    config = TokenizerConfig()
    tokenizer = Tokenizer(config)

    tokenized = tokenizer.encode('ابجد')
    assert len(tokenized) == 4

    tokenized = tokenizer.encode('ابجد', max_length=2)
    assert len(tokenized) == 2

    tokenized = tokenizer.encode('ابجد', max_length=10, padding=True)
    assert len(tokenized) == 10

    tokenized = tokenizer.encode('ابجد', max_length=10, padding=True, return_type='pt')
    assert tokenized.shape == torch.Size([10])

    tokenized = tokenizer.batch_encode(['ابجد', 'هوز'], return_type='pt')
    assert tokenized.shape == torch.Size([2, 4])

    # Pad to `max_length` explicitly; with the default 'longest' strategy the batch would
    # only be padded to the longest sequence (length 4) rather than to 10.
    tokenized = tokenizer.batch_encode(['ابجد', 'هوز'], max_length=10, padding='max_length', return_type='pt')
    assert tokenized.shape == torch.Size([2, 10])
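    # A small round-trip sketch added for illustration (not part of the original checks):
    # it assumes `decode` drops pad ids and, with `skip_special_tokens=True` (the default),
    # skips any id that has no vocabulary character, as implemented above.
    tokenized = tokenizer.encode('ابجد', max_length=10, padding=True, return_type='pt')
    assert tokenizer.decode(tokenized) == 'ابجد'
    # With skip_special_tokens=False, ids outside the vocabulary decode to `unk_token` ('$').
    assert tokenizer.decode([tokenizer.unk_token_id], skip_special_tokens=False) == '$'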