From f2a49e8953b315c0164d38f1141dd012c9980879 Mon Sep 17 00:00:00 2001
From: Charles Goddard
Date: Thu, 4 Jan 2024 14:40:16 -0800
Subject: [PATCH] Use idx->idx dict instead of full permutation matrix

---
 mergekit/merge_methods/tokenizer_permute.py | 24 +++++----
 mergekit/tokenizer.py                       | 56 +++++++++++++++------
 2 files changed, 56 insertions(+), 24 deletions(-)

diff --git a/mergekit/merge_methods/tokenizer_permute.py b/mergekit/merge_methods/tokenizer_permute.py
index ce154ed4..31871694 100644
--- a/mergekit/merge_methods/tokenizer_permute.py
+++ b/mergekit/merge_methods/tokenizer_permute.py
@@ -28,7 +28,7 @@ class TokenizerPermutationMerge(MergeMethod):
     def __call__(
         self,
         input_tensors: Dict[TensorReference, torch.Tensor],
-        embed_permutations: Dict[ModelReference, torch.IntTensor],
+        embed_permutations: Dict[ModelReference, Dict[int, int]],
         config: ConfigReader,
         **_kwargs,
     ) -> torch.Tensor:
@@ -47,15 +47,20 @@ def __call__(
             models.append(tr.model)
 
             x = input_tensors[tr]
-            p = embed_permutations[tr.model].to(dtype=x.dtype, device=x.device)
-            temp_dtype = torch.float32 if x.device.type == "cpu" else x.dtype
-            if p.shape[1] == x.shape[0]:
-                xp = (p.to(dtype=temp_dtype) @ x.to(dtype=temp_dtype)).to(x.dtype)
-            else:
-                raise RuntimeError("Shape mismatch")
+            p = embed_permutations[tr.model]
+
+            xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype, device=x.device)
+            mask = torch.zeros((len(p),), dtype=torch.bool, device=x.device)
+            for out_idx in p:
+                in_idx = p[out_idx]
+                if in_idx < 0:
+                    continue
+
+                xp[out_idx, :] = x[in_idx, :]
+                mask[out_idx] = 1
 
             expanded.append(xp)
-            masks.append(p.sum(dim=-1, keepdim=True) > 0)
+            masks.append(mask)
 
             is_base = tr.model == config.base_model
             if use_slerp:
@@ -63,11 +68,10 @@
                 weight = (1.0 - t) if is_base else t
             else:
                 weight = config.parameter("weight", model=tr.model, default=1.0)
-
             weights.append(weight)
 
         expanded = torch.stack(expanded, dim=0)
-        masks = torch.stack(masks, dim=0)
+        masks = torch.stack(masks, dim=0).unsqueeze(-1)
         weights = (
             torch.tensor(weights, dtype=expanded.dtype, device=expanded.device)
             .unsqueeze(-1)
diff --git a/mergekit/tokenizer.py b/mergekit/tokenizer.py
index 934ee928..bd2599b5 100644
--- a/mergekit/tokenizer.py
+++ b/mergekit/tokenizer.py
@@ -15,6 +15,7 @@
 
 import json
 import logging
+import tempfile
 from typing import Dict, Optional, Tuple
 
 import tokenizers
@@ -42,15 +43,23 @@ def get_vocab_size(model_path: str, trust_remote_code: bool) -> Optional[int]:
 def get_stripped_tokenizer(
     path: str, trust_remote_code: bool = False
 ) -> transformers.PreTrainedTokenizerFast:
+    """
+    Return a tokenizer for a model that only contains used tokens.
+
+    Strips any tokens with indices >= model.vocab_size.
+ """ tokenizer = transformers.AutoTokenizer.from_pretrained( path, trust_remote_code=trust_remote_code, use_fast=True ) - vocab_size = get_vocab_size(path) or len(tokenizer.get_vocab()) + vocab_size = get_vocab_size(path, trust_remote_code=trust_remote_code) or len( + tokenizer.get_vocab() + ) unused_toks = [ tok for tok, idx in tokenizer.get_vocab().items() if idx >= vocab_size ] if not unused_toks: + # we're good, ship it return tokenizer if not tokenizer.is_fast: @@ -91,12 +100,18 @@ def _keep_merge(m): def build_union_tokenizer( base_tok: transformers.PreTrainedTokenizerBase, tokenizers: Dict[ModelReference, transformers.PreTrainedTokenizerBase], + trust_remote_code: bool = False, ) -> transformers.PreTrainedTokenizerBase: out_added_tokens = {} out_vocab = {} + warned_added_tokens = set() + for model, tokenizer in tokenizers.items(): - vocab_size = get_vocab_size(model) or tokenizer.vocab_size + vocab_size = ( + get_vocab_size(model, trust_remote_code=trust_remote_code) + or tokenizer.vocab_size + ) added_tokens = tokenizer.added_tokens_decoder vocab = tokenizer.get_vocab() @@ -115,14 +130,22 @@ def build_union_tokenizer( for tok, info in tokenizer.added_tokens_decoder.items(): if tok in out_added_tokens: - if out_added_tokens[tok] != info: + if (out_added_tokens[tok] != info) and tok not in warned_added_tokens: logging.warning( f"Token '{tok}' added with multiple different settings, using first" ) + warned_added_tokens.add(tok) + continue out_added_tokens[tok] = info - res = base_tok + # HACK: save base tokenizer to temp dir and reload to avoid mutating base_tok + with tempfile.TemporaryDirectory() as p: + base_tok.save_pretrained(p, legacy_format=False, safe_serialization=True) + res = transformers.AutoTokenizer.from_pretrained( + p, use_fast=True, trust_remote_code=trust_remote_code + ) + orig_base_vocab = base_tok.get_vocab() for tok in out_vocab: if tok in out_added_tokens: @@ -148,19 +171,20 @@ def build_tokenizer( if base_model is None: raise RuntimeError("No models referenced") - tokenizer_out = get_stripped_tokenizer( + # + tokenizer_base = get_stripped_tokenizer( base_model.path, trust_remote_code=trust_remote_code ) # load all tokenizers logging.info("Loading tokenizers") - tokenizers = {base_model: tokenizer_out} + tokenizers = {base_model: tokenizer_base} for model in config.referenced_models(): if model == base_model: continue try: - model_tok = get_stripped_tokenizer( + model_tok = transformers.AutoTokenizer.from_pretrained( model.path, trust_remote_code=trust_remote_code ) except Exception: @@ -174,9 +198,11 @@ def build_tokenizer( # build final vocabulary if config.tokenizer_source == "base": # it done - pass + tokenizer_out = tokenizer_base elif config.tokenizer_source == "union": - tokenizer_out = build_union_tokenizer(tokenizer_out, tokenizers) + tokenizer_out = build_union_tokenizer( + tokenizer_base, tokenizers, trust_remote_code=trust_remote_code + ) elif config.tokenizer_source.startswith("model:"): tokenizer_out = transformers.AutoTokenizer.from_pretrained( config.tokenizer_source.removeprefix("model:"), @@ -199,9 +225,11 @@ def build_tokenizer( if vocab_size is None: vocab_size = len(model_vocab) - p = torch.zeros(len(vocab_out), vocab_size, dtype=torch.int32) - for tok in model_vocab: - if tok not in vocab_out: + p = {} + for tok in vocab_out: + new_idx = vocab_out[tok] + if tok not in model_vocab: + p[new_idx] = -1 continue orig_idx = model_vocab[tok] @@ -211,8 +239,8 @@ def build_tokenizer( ) continue - new_idx = vocab_out[tok] - p[new_idx, orig_idx] = 1 + 
+            p[new_idx] = orig_idx
+
         permutations[model] = p
 
     return tokenizer_out, permutations
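
Note: the sketch below illustrates how a Dict[int, int] permutation produced by
build_tokenizer is consumed, mirroring the row-gather loop this patch adds to
TokenizerPermutationMerge.__call__. The helper name apply_vocab_permutation and
the toy values are invented for illustration; only the mapping convention
(merged-vocab index -> source-vocab index, with -1 for tokens a model lacks)
comes from the patch. The payoff of the dict is size: a dense permutation
matrix for two 32k-token vocabularies holds ~1e9 entries, while the dict holds
one entry per merged token. The returned mask is what the merge later stacks
and unsqueezes so that only rows a model actually contributes are weighted.

    import torch

    def apply_vocab_permutation(p, x):
        """Gather rows of embedding table `x` into merged-vocab order per `p`."""
        xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype, device=x.device)
        mask = torch.zeros((len(p),), dtype=torch.bool, device=x.device)
        for out_idx, in_idx in p.items():
            if in_idx < 0:
                continue  # token missing from this model: row stays zero, mask stays False
            xp[out_idx, :] = x[in_idx, :]
            mask[out_idx] = True
        return xp, mask

    x = torch.randn(4, 8)          # toy embedding table: 4 tokens, dim 8
    p = {0: 2, 1: 0, 2: -1, 3: 3}  # merged idx -> source idx, -1 = missing
    xp, mask = apply_vocab_permutation(p, x)
    assert xp.shape == (4, 8)
    assert mask.tolist() == [True, True, False, True]
    assert torch.equal(xp[0], x[2])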