From 6f7a79ecd8e1ca1cf204a24c1bbbfaff8b88377c Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Sat, 30 Nov 2024 16:08:49 -0800 Subject: [PATCH] Add option to pad embeddings to multiple --- mergekit/merge.py | 11 +++++++++-- mergekit/plan.py | 3 +++ mergekit/tokenizer/config.py | 1 + mergekit/tokenizer/embed.py | 12 +++++++++++- tests/test_tokenizer.py | 32 +++++++++++++++++++++++++++++++- 5 files changed, 55 insertions(+), 4 deletions(-) diff --git a/mergekit/merge.py b/mergekit/merge.py index 60189f44..2d659505 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -98,7 +98,10 @@ def run_merge( tokenizer = value.tokenizer if tokenizer: - _update_config_vocab(cfg_out, tokenizer) + pad_to_multiple_of = None + if merge_config.tokenizer and merge_config.tokenizer.pad_to_multiple_of: + pad_to_multiple_of = merge_config.tokenizer.pad_to_multiple_of + _update_config_vocab(cfg_out, tokenizer, pad_to_multiple_of=pad_to_multiple_of) logging.info("Saving config") cfg_out.save_pretrained(out_path) @@ -263,9 +266,13 @@ def _model_out_config( def _update_config_vocab( config: transformers.PretrainedConfig, tokenizer: transformers.PreTrainedTokenizerBase, + pad_to_multiple_of: Optional[int] = None, ): + vocab_size = len(tokenizer.get_vocab()) + if pad_to_multiple_of and vocab_size % pad_to_multiple_of: + vocab_size = vocab_size + pad_to_multiple_of - (vocab_size % pad_to_multiple_of) try: - config.vocab_size = len(tokenizer.get_vocab()) + config.vocab_size = vocab_size except Exception as e: logging.warning( "Unable to set vocabulary size in output config - you may need to manually correct it.", diff --git a/mergekit/plan.py b/mergekit/plan.py index bdcd7004..5865becc 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -179,12 +179,15 @@ def plan_tensor( tensor_input_task = gather_tensors if self._tokenizer_task and weight.is_embed: token_cfg = {} + pad_to_multiple = None if cfg_reader.config.tokenizer: token_cfg = cfg_reader.config.tokenizer.tokens + pad_to_multiple = cfg_reader.config.tokenizer.pad_to_multiple_of tensor_input_task = PermutedEmbeddings( gather_tensors=gather_tensors, tokenizer_task=self._tokenizer_task, tokens=token_cfg, + pad_to_multiple_of=pad_to_multiple, base_model=base_model, ) diff --git a/mergekit/tokenizer/config.py b/mergekit/tokenizer/config.py index 94208385..7bdaeca2 100644 --- a/mergekit/tokenizer/config.py +++ b/mergekit/tokenizer/config.py @@ -49,3 +49,4 @@ class TokenEmbeddingConfig(BaseModel, frozen=True): class TokenizerConfig(BaseModel, frozen=True): source: Union[ModelReference, Literal["union"], Literal["base"]] = "union" tokens: Optional[Dict[str, TokenEmbeddingConfig]] = None + pad_to_multiple_of: Optional[int] = None diff --git a/mergekit/tokenizer/embed.py b/mergekit/tokenizer/embed.py index 3cdb1840..a853d1af 100644 --- a/mergekit/tokenizer/embed.py +++ b/mergekit/tokenizer/embed.py @@ -33,6 +33,7 @@ class PermutedEmbeddings(Task[Dict[ModelReference, torch.Tensor]]): gather_tensors: GatherTensors tokenizer_task: BuildTokenizer tokens: Optional[ImmutableMap[str, TokenEmbeddingConfig]] + pad_to_multiple_of: Optional[int] base_model: Optional[ModelReference] def arguments(self) -> Dict[str, Task]: @@ -51,6 +52,10 @@ def execute( vocab = tokenizer.get_vocab() vocab_size = len(vocab) + if self.pad_to_multiple_of and vocab_size % self.pad_to_multiple_of: + vocab_size = ( + vocab_size // self.pad_to_multiple_of + 1 + ) * self.pad_to_multiple_of embed_size = tensors[models[0]].shape[1] assert all( t.shape[1] == embed_size for t in tensors.values() @@ -59,7 +64,7 @@ def execute( dtype = tensors[models[0]].dtype device = tensors[models[0]].device - token_configs = dict(**self.tokens) or {} + token_configs = dict(**(self.tokens or {})) tokens_to_average = self.assign_embedding_sources( permutations, models, vocab, token_configs ) @@ -105,6 +110,11 @@ def execute( logging.error( f"No embedding for token {repr(token)} in model {model}!" ) + + if vocab_size > len(vocab): + # as suggested by https://nlp.stanford.edu/~johnhew/vocab-expansion.html + avg_embed = torch.mean(new_embed[: len(vocab), :], dim=0) + new_embed[len(vocab) :, :] = avg_embed result[model] = new_embed return result diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 17fafcc8..a799e8c4 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -7,7 +7,7 @@ import tokenizers import torch from common import make_picollama, run_and_check_merge -from transformers import LlamaTokenizerFast, PreTrainedTokenizerBase +from transformers import LlamaConfig, LlamaTokenizerFast, PreTrainedTokenizerBase from mergekit.config import InputModelDefinition, MergeConfiguration from mergekit.io import LazyTensorLoader @@ -270,6 +270,36 @@ def _check_embed(model_path: str): run_and_check_merge(config, validate=_check_embed) + def test_pad_to_multiple_of(self, model_chatml: str): + config = self.make_config( + [model_chatml], + base_model=model_chatml, + merge_method="linear", + tokenizer_config=TokenizerConfig( + source="base", + pad_to_multiple_of=16, + ), + ) + real_vocab_size = 64 + 2 + padded_size = (real_vocab_size // 16 + 1) * 16 + + def _check_result(model_path: str): + cfg = LlamaConfig.from_pretrained(model_path) + assert ( + cfg.vocab_size == padded_size + ), f"Expected vocab size {padded_size}, got {cfg.vocab_size}" + check_tokenizer( + expected_size=real_vocab_size, + must_contain=["<|im_start|>", "<|im_end|>"], + )(model_path) + + emb_out = ModelEmbeddings(model_path) + assert ( + emb_out.embed_tokens.shape[0] == padded_size + ), "Embedding size mismatch" + + run_and_check_merge(config, validate=_check_result) + def make_config( self, models: List[str],