Pad embeds to multiple (#465)
Add the ability to pad the output embeddings to a multiple of a
user-defined factor when merging tokenizers.

Config syntax example:
```yaml
merge_method: linear
models:
  - model: model_a
  - model: model_b
parameters:
  weight: 0.5
tokenizer:
  source: union
  pad_to_multiple_of: 64
```
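
For intuition, the new option only rounds the output embedding/vocab size up to the next multiple; the tokenizer's real vocabulary is unchanged. A standalone sketch of the rounding (illustrative helper, not mergekit code):

```python
def pad_vocab_size(vocab_size: int, pad_to_multiple_of: int) -> int:
    """Round vocab_size up to the next multiple of pad_to_multiple_of."""
    if pad_to_multiple_of and vocab_size % pad_to_multiple_of:
        vocab_size += pad_to_multiple_of - (vocab_size % pad_to_multiple_of)
    return vocab_size


assert pad_vocab_size(32002, 64) == 32064  # e.g. a 32k vocab plus two added tokens
assert pad_vocab_size(32000, 64) == 32000  # already a multiple: left unchanged
```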
cg123 authored Dec 1, 2024
1 parent 8d1a10d commit 01e60a2
Showing 6 changed files with 56 additions and 5 deletions.
mergekit/merge.py: 9 additions, 2 deletions
@@ -98,7 +98,10 @@ def run_merge(
             tokenizer = value.tokenizer
 
     if tokenizer:
-        _update_config_vocab(cfg_out, tokenizer)
+        pad_to_multiple_of = None
+        if merge_config.tokenizer and merge_config.tokenizer.pad_to_multiple_of:
+            pad_to_multiple_of = merge_config.tokenizer.pad_to_multiple_of
+        _update_config_vocab(cfg_out, tokenizer, pad_to_multiple_of=pad_to_multiple_of)
 
     logging.info("Saving config")
     cfg_out.save_pretrained(out_path)
@@ -263,9 +266,13 @@ def _model_out_config(
 def _update_config_vocab(
     config: transformers.PretrainedConfig,
     tokenizer: transformers.PreTrainedTokenizerBase,
+    pad_to_multiple_of: Optional[int] = None,
 ):
+    vocab_size = len(tokenizer.get_vocab())
+    if pad_to_multiple_of and vocab_size % pad_to_multiple_of:
+        vocab_size = vocab_size + pad_to_multiple_of - (vocab_size % pad_to_multiple_of)
     try:
-        config.vocab_size = len(tokenizer.get_vocab())
+        config.vocab_size = vocab_size
     except Exception as e:
         logging.warning(
             "Unable to set vocabulary size in output config - you may need to manually correct it.",
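
Note: `_update_config_vocab` pads by adding `pad_to_multiple_of - (vocab_size % pad_to_multiple_of)` only when there is a remainder; the embedding-side change in mergekit/tokenizer/embed.py below uses `(vocab_size // n + 1) * n` under the same guard. A quick illustrative check (not part of the commit) that the two forms agree whenever the guard fires:

```python
# Both rounding expressions used in this commit give the same padded size
# whenever vocab_size is not already a multiple.
for vocab_size in range(1, 2049):
    for multiple in (16, 32, 64):
        if vocab_size % multiple:
            assert vocab_size + multiple - (vocab_size % multiple) == (
                vocab_size // multiple + 1
            ) * multiple
```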
mergekit/plan.py: 3 additions, 0 deletions
@@ -182,12 +182,15 @@ def plan_tensor(
         tensor_input_task = gather_tensors
         if self._tokenizer_task and weight.is_embed:
             token_cfg = {}
+            pad_to_multiple = None
             if cfg_reader.config.tokenizer:
                 token_cfg = cfg_reader.config.tokenizer.tokens
+                pad_to_multiple = cfg_reader.config.tokenizer.pad_to_multiple_of
             tensor_input_task = PermutedEmbeddings(
                 gather_tensors=gather_tensors,
                 tokenizer_task=self._tokenizer_task,
                 tokens=token_cfg,
+                pad_to_multiple_of=pad_to_multiple,
                 base_model=base_model,
             )
 
mergekit/scripts/tokensurgeon.py: 1 addition, 1 deletion
@@ -210,7 +210,7 @@ def main(
     tokenizer.save_pretrained(out_path)
     cfg_out = arch_info.config
     try:
-        cfg_out.vocab_size = tokenizer.vocab_size
+        cfg_out.vocab_size = new_embed.shape[0]
     except AttributeError:
         LOG.error(
             "Could not set vocab size in config.json - you may need to update it manually."
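
The tokensurgeon fix matters because `tokenizer.vocab_size` in transformers reports only the base vocabulary and ignores added tokens, so it can disagree with the actual embedding row count; `new_embed.shape[0]` is authoritative. A small illustration (the model path is a placeholder):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/base-model")  # placeholder path
tok.add_tokens(["<|im_start|>", "<|im_end|>"])
print(tok.vocab_size)  # base vocabulary only; unchanged by add_tokens
print(len(tok))        # base vocabulary plus added tokens
```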
mergekit/tokenizer/config.py: 1 addition, 0 deletions
@@ -49,3 +49,4 @@ class TokenEmbeddingConfig(BaseModel, frozen=True):
 class TokenizerConfig(BaseModel, frozen=True):
     source: Union[ModelReference, Literal["union"], Literal["base"]] = "union"
     tokens: Optional[Dict[str, TokenEmbeddingConfig]] = None
+    pad_to_multiple_of: Optional[int] = None
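
Since `TokenizerConfig` is a frozen pydantic model, the new field is set at construction time and validated like the others. A minimal sketch, assuming the class is importable from `mergekit.tokenizer.config` as the file path suggests:

```python
from mergekit.tokenizer.config import TokenizerConfig  # assumed import path

cfg = TokenizerConfig(source="union", pad_to_multiple_of=64)
assert cfg.pad_to_multiple_of == 64  # frozen=True: instances are immutable
```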
mergekit/tokenizer/embed.py: 11 additions, 1 deletion
@@ -33,6 +33,7 @@ class PermutedEmbeddings(Task[Dict[ModelReference, torch.Tensor]]):
     gather_tensors: GatherTensors
     tokenizer_task: BuildTokenizer
     tokens: Optional[ImmutableMap[str, TokenEmbeddingConfig]]
+    pad_to_multiple_of: Optional[int]
     base_model: Optional[ModelReference]
 
     def arguments(self) -> Dict[str, Task]:
@@ -51,6 +52,10 @@ def execute(
 
         vocab = tokenizer.get_vocab()
         vocab_size = len(vocab)
+        if self.pad_to_multiple_of and vocab_size % self.pad_to_multiple_of:
+            vocab_size = (
+                vocab_size // self.pad_to_multiple_of + 1
+            ) * self.pad_to_multiple_of
         embed_size = tensors[models[0]].shape[1]
         assert all(
             t.shape[1] == embed_size for t in tensors.values()
@@ -59,7 +64,7 @@
         dtype = tensors[models[0]].dtype
         device = tensors[models[0]].device
 
-        token_configs = dict(**self.tokens) or {}
+        token_configs = dict(**(self.tokens or {}))
         tokens_to_average = self.assign_embedding_sources(
             permutations, models, vocab, token_configs
         )
@@ -105,6 +110,11 @@
                     logging.error(
                         f"No embedding for token {repr(token)} in model {model}!"
                     )
+
+            if vocab_size > len(vocab):
+                # as suggested by https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+                avg_embed = torch.mean(new_embed[: len(vocab), :], dim=0)
+                new_embed[len(vocab) :, :] = avg_embed
             result[model] = new_embed
 
         return result
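
The padded rows are not left at zero: they are filled with the mean of the real embeddings, following the vocabulary-expansion note linked in the comment above. A standalone sketch of that initialization (illustrative sizes, not mergekit code):

```python
import torch

real_vocab, padded_vocab, hidden = 66, 80, 32  # illustrative sizes
new_embed = torch.zeros(padded_vocab, hidden)
new_embed[:real_vocab] = torch.randn(real_vocab, hidden)  # stand-in for merged rows
# Fill the padding rows with the mean of the real embedding rows.
new_embed[real_vocab:, :] = new_embed[:real_vocab, :].mean(dim=0)
```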
tests/test_tokenizer.py: 31 additions, 1 deletion
@@ -7,7 +7,7 @@
 import tokenizers
 import torch
 from common import make_picollama, run_and_check_merge
-from transformers import LlamaTokenizerFast, PreTrainedTokenizerBase
+from transformers import LlamaConfig, LlamaTokenizerFast, PreTrainedTokenizerBase
 
 from mergekit.config import InputModelDefinition, MergeConfiguration
 from mergekit.io import LazyTensorLoader
@@ -270,6 +270,36 @@ def _check_embed(model_path: str):
 
         run_and_check_merge(config, validate=_check_embed)
 
+    def test_pad_to_multiple_of(self, model_chatml: str):
+        config = self.make_config(
+            [model_chatml],
+            base_model=model_chatml,
+            merge_method="linear",
+            tokenizer_config=TokenizerConfig(
+                source="base",
+                pad_to_multiple_of=16,
+            ),
+        )
+        real_vocab_size = 64 + 2
+        padded_size = (real_vocab_size // 16 + 1) * 16
+
+        def _check_result(model_path: str):
+            cfg = LlamaConfig.from_pretrained(model_path)
+            assert (
+                cfg.vocab_size == padded_size
+            ), f"Expected vocab size {padded_size}, got {cfg.vocab_size}"
+            check_tokenizer(
+                expected_size=real_vocab_size,
+                must_contain=["<|im_start|>", "<|im_end|>"],
+            )(model_path)
+
+            emb_out = ModelEmbeddings(model_path)
+            assert (
+                emb_out.embed_tokens.shape[0] == padded_size
+            ), "Embedding size mismatch"
+
+        run_and_check_merge(config, validate=_check_result)
+
     def make_config(
         self,
         models: List[str],
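
Outside the test suite, the same invariant can be checked by hand on any merged output: the saved config reports the padded size while the tokenizer keeps its real vocabulary. A rough sketch (the output directory is a placeholder, and the multiple of 64 matches the example config above):

```python
from transformers import AutoConfig, AutoTokenizer

out_path = "./merged-model"  # placeholder output directory
cfg = AutoConfig.from_pretrained(out_path)
tok = AutoTokenizer.from_pretrained(out_path)
assert cfg.vocab_size % 64 == 0    # padded to the configured multiple
assert len(tok) <= cfg.vocab_size  # tokenizer vocabulary itself is not padded
```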
