diff --git a/convert.py b/convert.py
index f3bf1798089cc..54dba5979cb38 100644
--- a/convert.py
+++ b/convert.py
@@ -142,6 +142,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
 @dataclass
 class Params:
     n_vocab: int
+    n_vocab_base: int
     n_embd: int
     n_mult: int
     n_head: int
@@ -169,6 +170,7 @@ def guessed(model: 'LazyModel') -> 'Params':
 
         return Params(
             n_vocab = n_vocab,
+            n_vocab_base=n_vocab,
             n_embd  = n_embd,
             n_mult  = 256,
             n_head  = n_head,
@@ -191,6 +193,7 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
 
         return Params(
             n_vocab = n_vocab,
+            n_vocab_base=n_vocab,
             n_embd  = n_embd,
             n_mult  = n_mult,
             n_head  = n_head,
@@ -215,6 +218,7 @@ def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
 
         return Params(
             n_vocab = n_vocab,
+            n_vocab_base=n_vocab,
             n_embd  = n_embd,
             n_mult  = n_mult,
             n_head  = n_head,
@@ -239,7 +243,7 @@ def load(model_plus: 'ModelPlus') -> 'Params':
 
 
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], fname_tokenizer_config: Optional[Path], vocabtype: Optional[str]) -> None:
         self.vocabtype = vocabtype
         if self.vocabtype == "bpe":
             self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
@@ -264,35 +268,72 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vo
         self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens
+        self.special_tokens_map: Dict[int, str] = {}
+
+        TOKEN_NAME_TO_ID: Dict[str, int] = {
+            "unk_token": self.sentencepiece_tokenizer.unk_id(),
+            "bos_token": self.sentencepiece_tokenizer.bos_id(),
+            "eos_token": self.sentencepiece_tokenizer.eos_id(),
+            "pad_token": self.sentencepiece_tokenizer.pad_id()
+        }
+
+        tokenizer_config: Dict[str, Any]
+        if fname_tokenizer_config is not None:
+            tokenizer_config = json.load(open(fname_tokenizer_config))
+        else:
+            tokenizer_config = {}
+        for key, value in tokenizer_config.items():
+            if not isinstance(value, dict) and not isinstance(value, str):
+                continue
+            token_id = TOKEN_NAME_TO_ID.get(key, -1)
+            if token_id == -1:
+                continue
+            self.special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value
+
+        special_tokens: Dict[str, Any]
+        if fname_special_tokens is not None:
+            special_tokens = json.load(open(fname_special_tokens))
+        else:
+            special_tokens = {}
+        for key, value in special_tokens.items():
+            if not isinstance(value, dict) and not isinstance(value, str):
+                continue
+            token_id = TOKEN_NAME_TO_ID.get(key, -1)
+            if token_id == -1 or token_id in self.special_tokens_map:
+                continue
+            self.special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value
 
     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
         tokenizer = self.sentencepiece_tokenizer
         if self.vocabtype == "bpe":
-          from transformers.models.gpt2 import tokenization_gpt2
-          byte_encoder = tokenization_gpt2.bytes_to_unicode()
-          byte_decoder = {v: k for k, v in byte_encoder.items()}
-          for i, item in enumerate(tokenizer):
-            text: bytes
-            text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
-            score: float = -i
-            yield text, score
+            from transformers.models.gpt2 import tokenization_gpt2
+            byte_encoder = tokenization_gpt2.bytes_to_unicode()
+            byte_decoder = {v: k for k, v in byte_encoder.items()}
+            for i, item in enumerate(tokenizer):
+                text: bytes
+                text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
+                score: float = -i
+                yield text, score
         else:
-          for i in range(tokenizer.vocab_size()):
-            text: bytes
-            if tokenizer.is_unknown(i):
-              text = " \u2047 ".encode("utf-8")
-            elif tokenizer.is_control(i):
-              text = b""
-            elif tokenizer.is_byte(i):
-              piece = tokenizer.id_to_piece(i)
-              if len(piece) != 6:
-                raise Exception(f"Invalid token: {piece}")
-              byte_value = int(piece[3:-1], 16)
-              text = struct.pack("B", byte_value)
-            else:
-              text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
-            score: float = tokenizer.get_score(i)
-            yield text, score
+            special_tokens = [tokenizer.bos_id(), tokenizer.eos_id(), tokenizer.pad_id()]
+            for i in range(tokenizer.vocab_size()):
+                text: bytes
+                if tokenizer.is_unknown(i):
+                    text = self.special_tokens_map.get(i, " \u2047 ").encode("utf-8")
+                elif i in special_tokens:
+                    text = self.special_tokens_map.get(i, "").encode("utf-8")
+                elif tokenizer.is_control(i):
+                    text = b""
+                elif tokenizer.is_byte(i):
+                    piece = tokenizer.id_to_piece(i)
+                    if len(piece) != 6:
+                        raise Exception(f"Invalid token: {piece}")
+                    byte_value = int(piece[3:-1], 16)
+                    text = struct.pack("B", byte_value)
+                else:
+                    text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
+                score: float = tokenizer.get_score(i)
+                yield text, score
 
     def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
         for text in self.added_tokens_list:
@@ -303,6 +344,12 @@ def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
         yield from self.sentencepiece_tokens()
         yield from self.added_tokens()
 
+    def all_special_tokens(self) -> Iterable[int]:
+        for token_id in self.special_tokens_map.keys():
+            yield token_id
+        for i in range(len(self.added_tokens_list)):
+            yield self.vocab_size_base + i
+
     def __repr__(self) -> str:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
@@ -310,11 +357,16 @@ def __repr__(self) -> str:
 class GGMLVocab:
     def __init__(self, tokens: List[Tuple[bytes, float]]):
         self.tokens = tokens
+        self.special_tokens = []
         self.vocab_size = len(tokens)
+        self.vocab_size_base = 0
 
     def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
         return self.tokens
 
+    def all_special_tokens(self) -> Iterable[int]:
+        return self.special_tokens
+
     def __repr__(self) -> str:
         return f"<GGMLVocab with {self.vocab_size} tokens>"
 
@@ -1072,10 +1124,10 @@ def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
             params.n_mult,
             params.n_head,
             params.n_layer,
-            params.n_embd // params.n_head,  # rot (obsolete)
+            params.n_vocab_base | 0xF0000000, # reuse obsolete rot value to store vocab_base
             file_type.value,
         ]
-        self.fout.write(struct.pack("i" * len(values), *values))
+        self.fout.write(struct.pack("I" * len(values), *values))
 
     def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
         sname = name.encode('utf-8')
@@ -1093,7 +1145,8 @@ def write_vocab(self, vocab: Vocab) -> None:
     @staticmethod
     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
+        params = Params(n_vocab=vocab.vocab_size, n_vocab_base=vocab.vocab_size_base, n_embd=0, n_mult=0,
+                        n_head=1, n_layer=0)
         of = OutputFile(fname_out)
         of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)
@@ -1249,8 +1302,10 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
                 f"Could not find tokenizer.model in {path} or its parent; "
                 "if it's in another directory, pass the directory as --vocab-dir")
     added_tokens_path = path.parent / "added_tokens.json"
+    special_tokens_path = path.parent / "special_tokens_map.json"
+    tokenizer_config_path = path.parent / "tokenizer_config.json"
     print(f"Loading vocab file {path}")
-    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None,
+    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else None, tokenizer_config_path if tokenizer_config_path.exists() else None,
                               vocabtype)
@@ -1313,6 +1368,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
         vocab = load_vocab(vocab_dir, args.vocabtype)
         params = Params.load(model_plus)
+        params.n_vocab_base = vocab.vocab_size_base
         model = model_plus.model
         model = do_necessary_conversions(model, params)
         output_type = pick_output_type(model, args.outtype)
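
The convert.py half of the change hinges on one encoding trick: the GGML file header has no spare slot, so `n_vocab_base` is written into the field that used to hold the obsolete `n_rot` value, with the top nibble forced to `0xF` so a loader can tell a tagged `vocab_base` apart from a legacy `n_rot` (both raw values sit far below `0xF0000000`). A minimal standalone sketch of that round-trip; the helper names are invented for illustration and are not part of either file:

```cpp
// Mirrors write_file_header() in convert.py (writer) and read_hparams()
// in llama.cpp (reader); encode/decode names exist only in this sketch.
#include <cstdint>
#include <cstdio>

// writer: tag n_vocab_base so readers know the field is not an old n_rot
uint32_t encode_vocab_base(uint32_t n_vocab_base) {
    return n_vocab_base | 0xF0000000;
}

// reader: an untagged value is a legacy n_rot, so fall back to n_vocab
uint32_t decode_vocab_base(uint32_t field, uint32_t n_vocab) {
    return (field & 0xF0000000) == 0 ? n_vocab : (field & ~0xF0000000);
}

int main() {
    // new-style file: vocab_base (32000) survives the round-trip
    printf("%u\n", decode_vocab_base(encode_vocab_base(32000), 32005));
    // legacy file: the field held n_rot (128 for 7B), fall back to n_vocab
    printf("%u\n", decode_vocab_base(128, 32000));
    return 0;
}
```

This is also why `write_file_header` now packs with `"I"` instead of `"i"`: once the marker nibble is set, the value no longer fits in a signed 32-bit integer.
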
diff --git a/llama.cpp b/llama.cpp
index 39aefd499dd0c..44104be66d710 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -181,13 +181,13 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab   = 32000;
+    uint32_t n_vocab_base = 32000;
     uint32_t n_ctx     = 512;   // this is provided as user input?
     uint32_t n_embd    = 4096;
     uint32_t n_mult    = 256;
     uint32_t n_head    = 32;
     uint32_t n_head_kv = 32;
     uint32_t n_layer   = 32;
-    uint32_t n_rot     = 64;
 
     // LLaMAv2
     // TODO: load from model data hparams
@@ -277,6 +277,12 @@ struct llama_vocab {
 
     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;
+
+    std::unordered_map<token, id> special_token_to_id;
+
+    void add_special_token(const token & word, id token_id) {
+        special_token_to_id[word] = token_id;
+    }
 };
 
 struct llama_model {
@@ -509,6 +515,7 @@ struct llama_file_loader {
         read_hparams();
         read_vocab();
         read_tensor_metadata(tensors_map);
+        set_vocab_sp();
     }
     void read_magic() {
         uint32_t magic = file.read_u32();
@@ -543,7 +550,8 @@ struct llama_file_loader {
         hparams.n_mult  = file.read_u32();
         hparams.n_head  = file.read_u32();
         hparams.n_layer = file.read_u32();
-        hparams.n_rot   = file.read_u32();
+        hparams.n_vocab_base = file.read_u32();
+        hparams.n_vocab_base = (hparams.n_vocab_base & 0xF0000000) == 0 ? hparams.n_vocab : (hparams.n_vocab_base & ~0xF0000000); // this bitwise operation is necessary for compatibility with older models
         hparams.ftype = (enum llama_ftype) file.read_u32();
 
         // LLaMAv2
@@ -612,6 +620,17 @@ struct llama_file_loader {
             tensors_map.name_to_idx[name] = tensors_map.tensors.size() - 1;
         }
     }
+    void set_vocab_sp() {
+        uint32_t vocab_sp = 3 + hparams.n_vocab - hparams.n_vocab_base;
+        vocab.special_token_to_id.reserve(vocab_sp);
+        for (uint32_t i = 0; i < vocab_sp; i++) {
+            llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i - 3 : i;
+            const auto & word = vocab.id_to_token[token_id].tok;
+            if (!word.empty()) {
+                vocab.add_special_token(word, token_id);
+            }
+        }
+    }
 };
 
 struct llama_file_saver {
@@ -635,7 +654,7 @@ struct llama_file_saver {
         file.write_u32(hparams.n_mult);
         file.write_u32(hparams.n_head);
         file.write_u32(hparams.n_layer);
-        file.write_u32(hparams.n_rot);
+        file.write_u32(hparams.n_vocab_base | 0xF0000000); // this bitwise operation is necessary for compatibility with older models
         file.write_u32(new_ftype);
     }
     void write_vocab() {
@@ -1100,7 +1119,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_head     = %u\n", __func__, hparams.n_head);
         fprintf(stderr, "%s: n_head_kv  = %u\n", __func__, hparams.n_head_kv);
         fprintf(stderr, "%s: n_layer    = %u\n", __func__, hparams.n_layer);
-        fprintf(stderr, "%s: n_rot      = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+        fprintf(stderr, "%s: n_rot      = %u\n", __func__, hparams.n_embd/hparams.n_head); // a.k.a. n_embd_head, n_head_dim
         fprintf(stderr, "%s: n_gqa      = %u\n", __func__, hparams.n_gqa());
         fprintf(stderr, "%s: rnorm_eps  = %.1e\n", __func__, hparams.f_rms_norm_eps);
         fprintf(stderr, "%s: n_ff       = %u\n", __func__, n_ff);
@@ -1418,7 +1437,7 @@ static struct ggml_cgraph * llama_build_graph(
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 
-    LLAMA_ASSERT(n_embd_head == hparams.n_rot);
+    LLAMA_ASSERT(n_embd_head == hparams.n_embd/hparams.n_head);
 
     const float freq_base  = hparams.rope_freq_base;
     const float freq_scale = hparams.rope_freq_scale;
@@ -1960,18 +1979,20 @@ struct llama_sp_bigram {
 struct llama_tokenizer {
     llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const char * text, size_t len, std::vector<llama_vocab::id> & output) {
+        symbols_.clear();
+
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
-        while (offs < text.size()) {
+        while (offs < len) {
             llama_sp_symbol sym;
-            size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
-            sym.text = text.c_str() + offs;
+            size_t char_len = std::min(len - offs, utf8_len(text[offs]));
+            sym.text = text + offs;
             sym.n = char_len;
             offs += char_len;
             sym.prev = index - 1;
-            sym.next = offs == text.size() ? -1 : index + 1;
+            sym.next = offs == len ? -1 : index + 1;
             index++;
             symbols_.emplace_back(sym);
         }
@@ -2074,7 +2095,45 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         output.push_back(llama_token_bos());
     }
 
-    tokenizer.tokenize(text, output);
+    if (vocab.special_token_to_id.empty()) {
+        tokenizer.tokenize(text.c_str(), text.size(), output);
+        return output;
+    }
+
+    size_t delim_start = 0;
+    size_t last_delim_end = 0;
+
+    while (delim_start < text.size()) {
+        size_t delim_end = 0;
+        llama_vocab::id token_id = -1;
+
+        for (const auto & mit : vocab.special_token_to_id) {
+            const std::string & delimiter = mit.first;
+            size_t end = delim_start + delimiter.size();
+            if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
+                if (token_id == -1 || end > delim_end) {
+                    token_id = mit.second;
+                    delim_end = end;
+                }
+            }
+        }
+
+        if (token_id != -1) {
+            if (last_delim_end < delim_start) {
+                tokenizer.tokenize(text.c_str() + last_delim_end, delim_start - last_delim_end, output);
+            }
+            output.push_back(token_id);
+            delim_start = delim_end;
+            last_delim_end = delim_end;
+        } else {
+            delim_start++;
+        }
+    }
+
+    if (last_delim_end < text.size()) {
+        tokenizer.tokenize(text.c_str() + last_delim_end, text.size() - last_delim_end, output);
+    }
+
     return output;
 }
@@ -4212,6 +4271,10 @@ llama_token llama_token_nl() {
     return 13;
 }
 
+void llama_add_special_token(struct llama_model * model, const char * token, llama_token token_id) {
+    model->vocab.add_special_token(token, token_id);
+}
+
 struct llama_timings llama_get_timings(struct llama_context * ctx) {
     struct llama_timings result = {
         /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
diff --git a/llama.h b/llama.h
index fa1977f2d9492..519ee716d0e63 100644
--- a/llama.h
+++ b/llama.h
@@ -373,6 +373,11 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos();  // end-of-sentence
     LLAMA_API llama_token llama_token_nl();   // next-line
 
+    LLAMA_API void llama_add_special_token(
+            struct llama_model * model,
+                     const char * token,
+                  llama_token token_id);
+
     // Grammar
     //
     LLAMA_API struct llama_grammar * llama_grammar_init(
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
index 87fde16453d25..3472180343c24 100644
--- a/tests/test-tokenizer-0.cpp
+++ b/tests/test-tokenizer-0.cpp
@@ -14,6 +14,9 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests()
         { " this is 🦙.cpp",                { 1, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, },
         { "w048 7tuijk dsdfhu",            { 1, 29893, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, },
         { "нещо на Български",             { 1, 821, 4851, 665, 1386, 29713, 1305, }, },
+        { "<🦙>test extra_id_1 test",      { 1, 32004, 1688, 29871, 32001, 259, 1243, }, },
+        { "<🦙>test extra_id_100 test",    { 1, 32004, 1688, 29871, 32002, 1243, }, },
+        { "<🦙>test extra_id_200 test",    { 1, 32004, 1688, 321, 32003, 1243, }, },
     };
     return _k_tests;
};
@@ -46,6 +49,11 @@ int main(int argc, char **argv) {
         return 1;
     }
 
+    llama_add_special_token(model, "extra_id_1", 32001);
+    llama_add_special_token(model, "extra_id_100", 32002);
+    llama_add_special_token(model, "xtra_id_200", 32003);
+    llama_add_special_token(model, "<🦙>", 32004);
+
     ctx = llama_new_context_with_model(model, lparams);
 
     if (ctx == NULL) {
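
For illustration, a hypothetical caller of the new `llama_add_special_token` API (the model path and token id below are placeholders, not from this patch): registered strings are matched left to right with the longest match winning, special ids are emitted directly, and only the text between matches goes through the regular SentencePiece tokenizer.

```cpp
// Sketch: register a runtime special token, then tokenize a string that
// contains it. Assumes a LLaMA-style vocab where id 32001 is a valid entry.
#include "llama.h"
#include <cstdio>
#include <vector>

int main() {
    llama_backend_init(false);

    llama_context_params lparams = llama_context_default_params();
    lparams.vocab_only = true;

    llama_model * model = llama_load_model_from_file("models/7B/ggml-model-f16.bin", lparams);
    if (model == NULL) {
        return 1;
    }

    // must happen before tokenization; the string is matched verbatim
    llama_add_special_token(model, "<extra_id_0>", 32001);

    llama_context * ctx = llama_new_context_with_model(model, lparams);

    const char * text = "translate: <extra_id_0> hello";
    std::vector<llama_token> tokens(64);
    const int n = llama_tokenize(ctx, text, tokens.data(), tokens.size(), true);

    for (int i = 0; i < n; i++) {
        printf("%d ", tokens[i]);   // "<extra_id_0>" comes out as the single id 32001
    }
    printf("\n");

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```
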