2x faster inference (#151)
* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update llama.py

* Update fast_lora.py

* Update llama.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update save.py

* Update fast_lora.py

* Update utils.py

* Update llama.py

* Update fast_lora.py

* Update swiglu.py

* Update save.py

* Update save.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Revert "Update llama.py"

This reverts commit a208ec4.

* Update llama.py

* Works?

* Update pyproject.toml

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Swiglu

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* attention_mask

* Update llama.py

* Update llama.py

* labels

* Update mistral.py

* Update llama.py

* attention mask

* Update save.py

* Update save.py

* Update mistral.py

* attention mask

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update dpo.py

* Patch saving

* Update save.py

* Update save.py

* patch_saving_functions

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* print

* Mistral patch

* Update mistral.py

* Update save.py

* saving

* Update llama.py

* Update llama.py

* Fast inference repatch

* Update llama.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update mistral.py

* Update __init__.py

* Fix inference

* Update mistral.py

* fast lm_head

* Remove fast path

* Update rope_embedding.py

* Update loader.py

* LlamaAttention_fast_forward_inference

* if past_key_value is not None and q_len == 1:

* revert inference

* Update loader.py

* past_key_value

* Update llama.py

* Update llama.py

* Fix SDPA

* Update llama.py

* padding

* Inference

* Update llama.py

* Revert

* Update mistral.py

* faster inference

* inference

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* inference

* Update llama.py

* Update utils.py

* faster inference

* Update llama.py

* revert

* lm_head

* Update llama.py

* inference

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* faster inference

* Update llama.py

* fast inference

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* torch compile

* past_key_values

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update llama.py

* fast inference + saving config.json

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* fast inference again

* more temp matrices

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* fast inference

* Update mistral.py

* Update llama.py

* SDPA

* attention_mask

* New version

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update utils.py

* Update utils.py
danielhanchen authored Feb 4, 2024
1 parent 051a73b commit 35f2ab4
Showing 7 changed files with 271 additions and 247 deletions.
12 changes: 6 additions & 6 deletions unsloth/kernels/rope_embedding.py
@@ -134,9 +134,9 @@ def forward(ctx, Q, cos, sin, position_ids):
         half = Q.shape[-1]//2
         RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
         Q *= cos
-        # Q.addcmul_(RH_Q, sin)
-        RH_Q *= sin
-        Q += RH_Q
+        Q.addcmul_(RH_Q, sin)
+        # RH_Q *= sin
+        # Q += RH_Q
         ctx.save_for_backward(cos, sin)
         return Q
     pass
@@ -148,9 +148,9 @@ def backward(ctx, dY):
         half = dY.shape[-1]//2
         RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
         dY *= cos
-        # dY.addcmul_(RH_dY, sin)
-        RH_dY *= sin
-        dY += RH_dY
+        dY.addcmul_(RH_dY, sin)
+        # RH_dY *= sin
+        # dY += RH_dY
         return dY, None, None, None
     pass
 pass
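The rope_embedding.py change above replaces the two-step rotate-half update (`RH_Q *= sin; Q += RH_Q`) with a single in-place `Q.addcmul_(RH_Q, sin)`, fusing the multiply and add and skipping one extra full-tensor write per call. A minimal self-contained sketch of the two formulations, with made-up tensor shapes rather than the kernel's real inputs:

import torch

def rope_rotate_two_step(Q, cos, sin):
    # Old path: separate multiply and add, one extra pass over the tensors.
    half = Q.shape[-1] // 2
    RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
    Q = Q * cos
    RH_Q = RH_Q * sin
    return Q + RH_Q

def rope_rotate_addcmul(Q, cos, sin):
    # New path: Q *= cos, then "Q += RH_Q * sin" fused into one in-place addcmul_.
    half = Q.shape[-1] // 2
    RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
    Q = Q * cos
    Q.addcmul_(RH_Q, sin)
    return Q

# Hypothetical shapes: (batch, heads, seq_len, head_dim) with broadcastable cos/sin.
Q = torch.randn(1, 8, 4, 64)
pos = torch.arange(4)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
freqs = torch.outer(pos.float(), inv_freq)
emb = torch.cat((freqs, freqs), dim = -1)
cos, sin = emb.cos(), emb.sin()
assert torch.allclose(rope_rotate_two_step(Q.clone(), cos, sin),
                      rope_rotate_addcmul(Q.clone(), cos, sin), atol = 1e-6)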
26 changes: 16 additions & 10 deletions unsloth/kernels/utils.py
@@ -119,7 +119,7 @@ def fast_gemv(X, W, quant_state, out = None):
     # For fast X @ W where seq_len == 1
     # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
     bsz, q_len, hd = X.shape
-    assert(q_len == 1)
+    # assert(q_len == 1)
 
     if type(quant_state) is not list:
         # https://github.com/TimDettmers/bitsandbytes/pull/763/files
@@ -138,7 +138,7 @@ def fast_gemv(X, W, quant_state, out = None):
         offset, state2 = compressed_stats
         absmax2, code2, blocksize2, _, _, _, _ = state2
     pass
-    assert(dtype == X.dtype)
+    # assert(dtype == X.dtype)
     bout = shape[0]
 
     if out is None:
@@ -152,7 +152,7 @@ def fast_gemv(X, W, quant_state, out = None):
     k = shape[1]
     lda = shape[0]
     ldc = shape[0]
-    ldb = (X.shape[-1]+1)//2
+    ldb = (hd+1)//2
     m = ctypes.c_int32(m)
     n = ctypes.c_int32(n)
     k = ctypes.c_int32(k)
@@ -192,9 +192,9 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
     bsz, _, in_dim = X.shape
 
     if W_quant is None:
-        out = torch.matmul(X, W.t())
-    elif bsz <= 4:
-        # Only batches of 4 are faster with Gemv
+        out = torch.matmul(X, W.t(), out = out)
+    elif bsz <= 2:
+        # Only batches of 2 are faster with Gemv
         out = fast_gemv(X, W, W_quant, out = out)
     else:
         W = fast_dequantize(W.t(), W_quant)
@@ -205,14 +205,20 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
     if lora_A is not None:
         out_dim = out.shape[2]
         dtype = X.dtype
+
+        if not hasattr(lora_A, "_fast_lora"):
+            lora_A._fast_lora = lora_A.to(dtype)
+            lora_B._fast_lora = lora_B.to(dtype)
+        pass
+
         if bsz == 1:
             out = out.view(out_dim)
-            temp_lora = torch.mv(lora_A.to(dtype), X.ravel(), out = temp_lora)
-            out.addmv_(lora_B.to(dtype), temp_lora, alpha = lora_S)
+            temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
+            out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
         else:
             out = out.view(bsz, out_dim)
-            temp_lora = torch.mm(X.view(bsz, in_dim), lora_A.to(dtype).t(), out = temp_lora)
-            out.addmm_(temp_lora, lora_B.to(dtype).t(), alpha = lora_S)
+            temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
+            out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
         pass
         out = out.view(bsz, 1, out_dim)
     pass
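The fast_linear_forward changes above do two things for decode-time speed: batches of at most 2 (down from 4) are routed through the single-token GEMV path, and the LoRA A/B matrices are cast to the activation dtype once and cached on the modules as `_fast_lora`, rather than calling `.to(dtype)` on every generated token. A rough standalone sketch of the caching idea for the bsz == 1 case; the module layout, names, and sizes here are illustrative assumptions, not the actual unsloth code path:

import torch

def get_fast_lora(lora_A, lora_B, dtype):
    # Cast the LoRA weights to the activation dtype once and cache them,
    # mirroring the `_fast_lora` attribute introduced in this commit.
    if not hasattr(lora_A, "_fast_lora"):
        lora_A._fast_lora = lora_A.weight.to(dtype)
        lora_B._fast_lora = lora_B.weight.to(dtype)
    return lora_A._fast_lora, lora_B._fast_lora

def lora_update_single_token(out, X, lora_A, lora_B, scaling, temp_lora = None):
    # bsz == 1, q_len == 1: out is (1, 1, out_dim), X is (1, 1, in_dim).
    A, B = get_fast_lora(lora_A, lora_B, X.dtype)   # A: (rank, in_dim), B: (out_dim, rank)
    out_dim = out.shape[2]
    out = out.view(out_dim)
    # temp_lora is a rank-sized scratch buffer; passing it via `out=` avoids a
    # fresh allocation on every decode step.
    temp_lora = torch.mv(A, X.ravel(), out = temp_lora)
    out.addmv_(B, temp_lora, alpha = scaling)       # out += scaling * B @ (A @ x)
    return out.view(1, 1, out_dim)

# Illustrative usage (hypothetical sizes: in_dim = out_dim = 4096, rank = 16).
in_dim, rank, out_dim = 4096, 16, 4096
lora_A = torch.nn.Linear(in_dim, rank, bias = False)
lora_B = torch.nn.Linear(rank, out_dim, bias = False)
X        = torch.randn(1, 1, in_dim)
base_out = torch.zeros(1, 1, out_dim)
out = lora_update_single_token(base_out, X, lora_A, lora_B, scaling = 2.0)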
2 changes: 1 addition & 1 deletion unsloth/models/_utils.py
@@ -23,7 +23,7 @@
 platform_system = platform_system()
 import math
 
-__version__ = "2024.1"
+__version__ = "2024.2"
 
 # Get Flash Attention v2 if Ampere (RTX 30xx, A100)
 major_version, minor_version = torch.cuda.get_device_capability()
