Potentially fix tokenizer merge fp16 on cpu
cg123 committed Dec 13, 2023
1 parent 177dac7 commit c37afb2
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mergekit/merge_methods/tokenizer_permute.py
@@ -48,8 +48,9 @@ def __call__(

 x = input_tensors[tr]
 p = embed_permutations[tr.model].to(dtype=x.dtype, device=x.device)
+temp_dtype = torch.float32 if x.device.type == "cpu" else x.dtype
 if p.shape[1] == x.shape[0]:
-    xp = p @ x
+    xp = (p.to(dtype=temp_dtype) @ x.to(dtype=temp_dtype)).to(x.dtype)
 else:
     raise RuntimeError("Shape mismatch")
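
The change above can be tried in isolation with the sketch below. It assumes the motivation is that a half-precision matrix product may be unavailable on the CPU in some PyTorch builds, which the commit title suggests but does not state. The helper name `permute_embeddings` and the toy tensors are hypothetical and not part of mergekit.

```python
import torch


def permute_embeddings(p: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Compute p @ x, doing the product in float32 when x is on the CPU.

    Hypothetical helper for illustration only; not part of mergekit.
    """
    if p.shape[1] != x.shape[0]:
        raise RuntimeError("Shape mismatch")
    # On the CPU, upcast both operands to float32 for the matmul, then cast
    # the result back to the original dtype. On other devices, keep x.dtype.
    temp_dtype = torch.float32 if x.device.type == "cpu" else x.dtype
    return (p.to(dtype=temp_dtype) @ x.to(dtype=temp_dtype)).to(x.dtype)


# Toy data: a 3x3 permutation matrix applied to three 2-dimensional embeddings.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float16)
p = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]], dtype=torch.float16)
xp = permute_embeddings(p, x)
```

The result keeps the dtype of `x`, so callers see the same output type as before the change; only the intermediate product runs in float32.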
