BERT tokenizer fixes (#6498)
Key changes:
* BERT conversion: fix abuse of LlamaHfVocab, do not set BOS or EOS
* Nomic Embed conversion: pad vocab instead of slicing embedding tensor
* llama_tokenize: handle added special tokens like HF does
cebtenzzre authored Apr 9, 2024
1 parent c4a3a4f commit 1b67731
Showing 20 changed files with 221 additions and 194 deletions.
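
For context on the llama_tokenize change: the Hugging Face behavior being matched looks roughly like the snippet below. This is an illustration using the transformers package and bert-base-uncased as an example model, not code from this commit. add_special_tokens controls whether the tokenizer wraps the sequence in its own special tokens ([CLS]/[SEP] for BERT), and special tokens written in the input text are recognized as single tokens; these are the two behaviors the renamed add_special and parse_special flags in the diff below correspond to.

# Illustrative only -- not part of this commit. Shows the HF tokenizer behavior
# that llama_tokenize's renamed flags are being aligned with.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # example model

# add_special_tokens ~ add_special: wrap the sequence in the model's own
# special tokens ([CLS] ... [SEP] for BERT)
print(tok("hello world", add_special_tokens=True)["input_ids"])
print(tok("hello world", add_special_tokens=False)["input_ids"])

# special tokens typed in the text are parsed as single tokens ~ parse_special
print(tok.tokenize("hello [SEP] world"))  # e.g. ['hello', '[SEP]', 'world']
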
16 changes: 8 additions & 8 deletions common/common.cpp
@@ -2212,23 +2212,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params
 std::vector<llama_token> llama_tokenize(
         const struct llama_context * ctx,
         const std::string & text,
-        bool add_bos,
-        bool special) {
-    return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
+        bool add_special,
+        bool parse_special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
 }
 
 std::vector<llama_token> llama_tokenize(
         const struct llama_model * model,
         const std::string & text,
-        bool add_bos,
-        bool special) {
+        bool add_special,
+        bool parse_special) {
     // upper limit for the number of tokens
-    int n_tokens = text.length() + add_bos;
+    int n_tokens = text.length() + 2 * add_special;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
8 changes: 4 additions & 4 deletions common/common.h
@@ -223,14 +223,14 @@ void llama_batch_add(
 std::vector<llama_token> llama_tokenize(
         const struct llama_context * ctx,
         const std::string & text,
-        bool add_bos,
-        bool special = false);
+        bool add_special,
+        bool parse_special = false);
 
 std::vector<llama_token> llama_tokenize(
         const struct llama_model * model,
         const std::string & text,
-        bool add_bos,
-        bool special = false);
+        bool add_special,
+        bool parse_special = false);
 
 // tokenizes a token into a piece
 // should work similar to Python's `tokenizer.id_to_piece`
53 changes: 19 additions & 34 deletions convert-hf-to-gguf.py
@@ -227,15 +227,14 @@ def _get_part_names(self):
             return ("pytorch_model.bin",)
         return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
 
-    def _set_vocab_gpt2(self):
-        dir_model = self.dir_model
-        hparams = self.hparams
+    # used for GPT-2 BPE and WordPiece vocabs
+    def get_basic_vocab(self) -> tuple[list[str], list[int]]:
         tokens: list[str] = []
         toktypes: list[int] = []
 
         from transformers import AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(dir_model)
-        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
         assert max(tokenizer.vocab.values()) < vocab_size
 
         reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
@@ -255,11 +254,15 @@ def _set_vocab_gpt2(self):
                 tokens.append(reverse_vocab[i])
                 toktypes.append(gguf.TokenType.NORMAL)
 
+        return tokens, toktypes
+
+    def _set_vocab_gpt2(self) -> None:
+        tokens, toktypes = self.get_basic_vocab()
         self.gguf_writer.add_tokenizer_model("gpt2")
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_types(toktypes)
 
-        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
         special_vocab.add_to_gguf(self.gguf_writer)
 
     def _set_vocab_qwen(self):
@@ -2043,34 +2046,25 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_pooling_type(pooling_type)
 
     def set_vocab(self):
-        # use huggingface vocab to get all tokens
-        vocab = LlamaHfVocab(self.dir_model, ignore_nonllama=True)
-        tokens, scores, toktypes = zip(*vocab.all_tokens())
-        assert len(tokens) == vocab.vocab_size
-        self.vocab_size = vocab.vocab_size
+        tokens, toktypes = self.get_basic_vocab()
+        self.vocab_size = len(tokens)
 
         # we need this to validate the size of the token_type embeddings
         # though currently we are passing all zeros to the token_type embeddings
-        n_token_types = len(set(toktypes))
-        self.gguf_writer.add_token_type_count(n_token_types)
+        self.gguf_writer.add_token_type_count(2)  # "Sequence A" or "Sequence B"
 
         # convert to phantom space vocab
-        def phantom(tok, typ):
-            if tok.startswith(b"[") and tok.endswith(b"]"):
+        def phantom(tok):
+            if tok.startswith("[") and tok.endswith("]"):
                 return tok
-            if tok.startswith(b"##"):
+            if tok.startswith("##"):
                 return tok[2:]
-            return b"\xe2\x96\x81" + tok
-        tokens = tuple(phantom(t, y) for t, y in zip(tokens, toktypes))
-
-        # set up bos and eos tokens (cls and sep)
-        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
-        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
+            return "\u2581" + tok
+        tokens = list(map(phantom, tokens))
 
         # add vocab to gguf
         self.gguf_writer.add_tokenizer_model("bert")
         self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_scores(scores)
         self.gguf_writer.add_token_types(toktypes)
 
         # handle special tokens
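
A quick illustration of the phantom-space conversion in the hunk above, applied to a made-up WordPiece token list (this snippet is not part of the commit):

# Illustration only: mirrors the phantom() helper above. Bracketed specials pass
# through, "##" continuation pieces lose their prefix, and word-initial pieces
# gain the U+2581 "phantom space" that llama.cpp's tokenizer expects.
def phantom(tok: str) -> str:
    if tok.startswith("[") and tok.endswith("]"):
        return tok
    if tok.startswith("##"):
        return tok[2:]
    return "\u2581" + tok

print([phantom(t) for t in ["[CLS]", "hello", "##ing", "[SEP]"]])
# -> ['[CLS]', '▁hello', 'ing', '[SEP]']
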
@@ -2142,16 +2136,6 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
 
-    def get_tensors(self):
-        assert self.vocab_size is not None
-        for name, data in super().get_tensors():
-            # Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly.
-            if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size:
-                rounded_vocab_size = (self.vocab_size + 63) // 64 * 64
-                assert data.shape == (rounded_vocab_size, self.hparams["n_embd"])
-                data = data[:self.vocab_size, :]
-            yield name, data
-
 
 @Model.register("GemmaForCausalLM")
 class GemmaModel(Model):
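
The hunk above drops Nomic Embed's old workaround of slicing the padded embedding tensor down to the vocab size. Per the commit message, the conversion now pads the vocab to the tensor size instead; the replacement code sits in a part of the diff not loaded here, but the idea is roughly the sketch below (the helper name and dummy-token naming are assumptions, not the commit's code):

# Hypothetical sketch only -- the real padding code is in an unloaded part of this diff.
# Nomic Embed's embedding tensor has its row count rounded up to a multiple of 64,
# so the token list is padded with dummy entries to match instead of slicing the tensor.
import gguf

def pad_vocab_to(tokens: list[str], toktypes: list[int], n_vocab_rounded: int) -> None:
    # n_vocab_rounded would be e.g. (len(tokens) + 63) // 64 * 64
    for i in range(len(tokens), n_vocab_rounded):
        tokens.append(f"[PAD{i}]")              # placeholder piece, never produced
        toktypes.append(gguf.TokenType.UNUSED)
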
@@ -2327,7 +2311,8 @@ def write_tensors(self):
                 data = data.astype(np.float32)
 
             # if f16 desired, convert big float32 2-dim weight tensors to float16
-            if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
+            new_weight_name = new_name[:-len(".weight")] if new_name.endswith(".weight") else ""
+            if self.ftype == 1 and data_dtype == np.float32 and new_weight_name.endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
                 data = data.astype(np.float16)
 
             print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
2 changes: 2 additions & 0 deletions convert-persimmon-to-gguf.py
@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+from __future__ import annotations
+
 import argparse
 import os
 import sys
21 changes: 10 additions & 11 deletions convert.py
@@ -33,7 +33,7 @@
 import gguf
 
 if TYPE_CHECKING:
-    from typing import TypeAlias
+    from typing_extensions import Self, TypeAlias
 
 if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
     faulthandler.register(signal.SIGUSR1)
@@ -517,17 +517,15 @@ class LlamaHfVocab(Vocab):
     tokenizer_model = "llama"
     name = "hfft"
 
-    def __init__(self, base_path: Path, ignore_nonllama: bool = False):
+    def __init__(self, base_path: Path):
         fname_tokenizer = base_path / FAST_TOKENIZER_FILE
         # if this fails, FileNotFoundError propagates to caller
         with open(fname_tokenizer, encoding='utf-8') as f:
             tokenizer_json = json.load(f)
 
         # pre-check so we know if we need transformers
         tokenizer_model: dict[str, Any] = tokenizer_json['model']
-        if ignore_nonllama:
-            pass # workaround incorrect use of this class for WordPiece
-        elif (
+        if (
             tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
             or tokenizer_json['decoder']['type'] != 'Sequence'
         ):
@@ -647,16 +645,17 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
 
 
 class Tensor(ABC):
+    ndarray: NDArray
     data_type: DataType
 
     @abstractmethod
-    def astype(self, data_type: DataType) -> Tensor: ...
+    def astype(self, data_type: DataType) -> Self: ...
     @abstractmethod
-    def permute(self, n_head: int, n_head_kv: int) -> Tensor: ...
+    def permute(self, n_head: int, n_head_kv: int) -> Self: ...
     @abstractmethod
-    def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ...
+    def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ...
     @abstractmethod
-    def part(self, n_part: int) -> UnquantizedTensor: ...
+    def part(self, n_part: int) -> Self: ...
     @abstractmethod
     def to_ggml(self) -> GGMLCompatibleTensor: ...
 
@@ -673,13 +672,13 @@ def __init__(self, ndarray: NDArray):
         self.ndarray = ndarray
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
 
-    def astype(self, data_type: DataType) -> Tensor:
+    def astype(self, data_type: DataType) -> UnquantizedTensor:
         dtype = data_type.dtype
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
 
-    def to_ggml(self) -> UnquantizedTensor:
+    def to_ggml(self) -> Self:
         return self
 
     def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor:
6 changes: 3 additions & 3 deletions examples/embedding/embedding.cpp
@@ -123,10 +123,10 @@ int main(int argc, char ** argv) {
         inputs.push_back(inp);
     }
 
-    // add eos if not present
+    // add SEP if not present
     for (auto & inp : inputs) {
-        if (inp.empty() || inp.back() != llama_token_eos(model)) {
-            inp.push_back(llama_token_eos(model));
+        if (inp.empty() || inp.back() != llama_token_sep(model)) {
+            inp.push_back(llama_token_sep(model));
         }
     }
 
3 changes: 2 additions & 1 deletion examples/imatrix/imatrix.cpp
@@ -349,12 +349,13 @@ static void process_logits(
 static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
     const int n_ctx = llama_n_ctx(ctx);
 
     auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
-    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
 
     auto tim2 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
5 changes: 3 additions & 2 deletions examples/infill/infill.cpp
@@ -239,6 +239,7 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s\n", get_system_info(params).c_str());
     }
     const bool add_bos = llama_should_add_bos_token(model);
+    GGML_ASSERT(llama_add_eos_token(model) != 1);
     LOG("add_bos: %d\n", add_bos);
 
     bool suff_rm_leading_spc = params.escape;
@@ -279,10 +280,10 @@ int main(int argc, char ** argv) {
     if (ctx_guidance) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
+        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
         original_prompt_len = original_inp.size();
3 changes: 1 addition & 2 deletions examples/llava/llava-cli.cpp
@@ -146,7 +146,6 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     int n_past = 0;
 
     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
-    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_llava->ctx_llama));
 
     std::string system_prompt, user_prompt;
     size_t image_pos = prompt.find("<image>");
@@ -180,7 +179,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
         }
     }
 
-    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, add_bos);
+    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true);
     llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
 
5 changes: 1 addition & 4 deletions examples/lookahead/lookahead.cpp
@@ -64,13 +64,10 @@ int main(int argc, char ** argv) {
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
     // Tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-    LOG("add_bos tgt: %d\n", add_bos);
-
     std::vector<llama_token> inp;
     std::vector<llama_token> all;
 
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
     all = inp;
 
     const int max_context_size = llama_n_ctx(ctx);
4 changes: 1 addition & 3 deletions examples/lookup/lookup-create.cpp
@@ -28,10 +28,8 @@ int main(int argc, char ** argv){
     GGML_ASSERT(model != nullptr);
 
     // tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
     fprintf(stderr, "%s: tokenization done\n", __func__);
 
 
5 changes: 1 addition & 4 deletions examples/lookup/lookup-stats.cpp
@@ -34,11 +34,8 @@ int main(int argc, char ** argv){
    GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-    LOG("add_bos tgt: %d\n", add_bos);
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
 
     llama_ngram_cache ngram_cache_context;
     llama_ngram_cache ngram_cache_dynamic;
5 changes: 1 addition & 4 deletions examples/lookup/lookup.cpp
@@ -42,11 +42,8 @@ int main(int argc, char ** argv){
    GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-    LOG("add_bos tgt: %d\n", add_bos);
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
 
     llama_ngram_cache ngram_cache_context;
     llama_ngram_cache ngram_cache_dynamic;
13 changes: 7 additions & 6 deletions examples/main/main.cpp
@@ -246,6 +246,7 @@ int main(int argc, char ** argv) {
     }
 
     const bool add_bos = llama_should_add_bos_token(model);
+    GGML_ASSERT(llama_add_eos_token(model) != 1);
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
@@ -255,7 +256,7 @@ int main(int argc, char ** argv) {
         if (params.chatml) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
-        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
     } else {
         LOG("use session tokens\n");
         embd_inp = session_tokens;
@@ -277,10 +278,10 @@ int main(int argc, char ** argv) {
     if (ctx_guidance) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true);
+        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
         original_prompt_len = original_inp.size();
@@ -339,14 +340,14 @@ int main(int argc, char ** argv) {
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
-    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true, true);
+    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
 
     LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
     LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
 
     // chatml prefix & suffix
-    const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true);
+    const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", true, true);
     const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true);
 
     LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str());
(The remaining 6 changed files are not shown in this view.)
