diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py
index d7fca30b..a9527520 100644
--- a/unsloth/kernels/rope_embedding.py
+++ b/unsloth/kernels/rope_embedding.py
@@ -134,9 +134,9 @@ def forward(ctx, Q, cos, sin, position_ids):
         half = Q.shape[-1]//2
         RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
         Q *= cos
-        # Q.addcmul_(RH_Q, sin)
-        RH_Q *= sin
-        Q += RH_Q
+        Q.addcmul_(RH_Q, sin)
+        # RH_Q *= sin
+        # Q += RH_Q
         ctx.save_for_backward(cos, sin)
         return Q
     pass
@@ -148,9 +148,9 @@ def backward(ctx, dY):
         half = dY.shape[-1]//2
         RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
         dY *= cos
-        # dY.addcmul_(RH_dY, sin)
-        RH_dY *= sin
-        dY += RH_dY
+        dY.addcmul_(RH_dY, sin)
+        # RH_dY *= sin
+        # dY += RH_dY
         return dY, None, None, None
     pass
 pass
diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py
index 4f4ce410..2ed2a685 100644
--- a/unsloth/kernels/utils.py
+++ b/unsloth/kernels/utils.py
@@ -119,7 +119,7 @@ def fast_gemv(X, W, quant_state, out = None):
     # For fast X @ W where seq_len == 1
     # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
     bsz, q_len, hd = X.shape
-    assert(q_len == 1)
+    # assert(q_len == 1)

     if type(quant_state) is not list:
         # https://github.com/TimDettmers/bitsandbytes/pull/763/files
@@ -138,7 +138,7 @@ def fast_gemv(X, W, quant_state, out = None):
         offset, state2 = compressed_stats
         absmax2, code2, blocksize2, _, _, _, _ = state2
     pass
-    assert(dtype == X.dtype)
+    # assert(dtype == X.dtype)
     bout = shape[0]

     if out is None:
@@ -152,7 +152,7 @@ def fast_gemv(X, W, quant_state, out = None):
     k = shape[1]
     lda = shape[0]
     ldc = shape[0]
-    ldb = (X.shape[-1]+1)//2
+    ldb = (hd+1)//2
     m = ctypes.c_int32(m)
     n = ctypes.c_int32(n)
     k = ctypes.c_int32(k)
@@ -192,9 +192,9 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
     bsz, _, in_dim = X.shape

     if W_quant is None:
-        out = torch.matmul(X, W.t())
-    elif bsz <= 4:
-        # Only batches of 4 are faster with Gemv
+        out = torch.matmul(X, W.t(), out = out)
+    elif bsz <= 2:
+        # Only batches of 2 are faster with Gemv
         out = fast_gemv(X, W, W_quant, out = out)
     else:
         W = fast_dequantize(W.t(), W_quant)
@@ -205,14 +205,20 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
     if lora_A is not None:
         out_dim = out.shape[2]
         dtype = X.dtype
+
+        if not hasattr(lora_A, "_fast_lora"):
+            lora_A._fast_lora = lora_A.to(dtype)
+            lora_B._fast_lora = lora_B.to(dtype)
+        pass
+
         if bsz == 1:
             out = out.view(out_dim)
-            temp_lora = torch.mv(lora_A.to(dtype), X.ravel(), out = temp_lora)
-            out.addmv_(lora_B.to(dtype), temp_lora, alpha = lora_S)
+            temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
+            out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
         else:
             out = out.view(bsz, out_dim)
-            temp_lora = torch.mm(X.view(bsz, in_dim), lora_A.to(dtype).t(), out = temp_lora)
-            out.addmm_(temp_lora, lora_B.to(dtype).t(), alpha = lora_S)
+            temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
+            out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
         pass
         out = out.view(bsz, 1, out_dim)
     pass
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 5f358514..617b8509 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -23,7 +23,7 @@ platform_system = platform_system()
 import math

-__version__ = "2024.1"
+__version__ = "2024.2"

 # Get Flash Attention v2 if Ampere (RTX 30xx, A100)
 major_version, minor_version = torch.cuda.get_device_capability()
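# Note on the rope_embedding.py hunks above: the two-step update (RH_Q *= sin; Q += RH_Q)
# is swapped for a single fused, in-place torch.addcmul_, computing the same identity
# Q <- Q*cos + rotate_half(Q)*sin with one fewer kernel launch and no extra temporary write.
# A minimal, runnable sketch of that identity (the shapes below are illustrative
# assumptions, not taken from the patch):
import torch

def rope_rotate_inplace(Q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    half = Q.shape[-1] // 2
    RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)  # rotate_half(Q)
    Q *= cos
    Q.addcmul_(RH_Q, sin)  # Q = Q*cos + rotate_half(Q)*sin, fused and in place
    return Q

# Quick numerical check against the unfused two-step version:
Q, cos, sin = torch.randn(1, 8, 4, 64), torch.randn(1, 1, 4, 64), torch.randn(1, 1, 4, 64)
half = Q.shape[-1] // 2
expected = Q*cos + torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)*sin
assert torch.allclose(rope_rotate_inplace(Q.clone(), cos, sin), expected, atol = 1e-5)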
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 60290052..40e5e56e 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -20,6 +20,9 @@
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from ..kernels import *
 from ._utils import *
 from ._utils import __version__
@@ -69,11 +72,14 @@ def original_apply_o(self, X):
 from math import sqrt as math_sqrt


-def _LlamaAttention_fast_forward_inference(
+KV_CACHE_INCREMENT = 128 # KV Cache update size
+
+def LlamaAttention_fast_forward_inference(
     self,
     hidden_states: torch.Tensor,
     past_key_value: Optional[Tuple[torch.Tensor]],
     position_ids,
+    do_prefill = False,
 ):
     """
         https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
@@ -103,146 +109,67 @@ def _LlamaAttention_fast_forward_inference(
         This means we can pass in a row of Q, but we need to
         remember K and V, which are called the KV cache.
     """
+    Xn = hidden_states
+    bsz, _, hd = hidden_states.size()
+    K1, V1 = past_key_value
+    dtype = Xn.dtype
+
     n_heads = self.num_heads
     n_groups = self.num_key_value_groups
     n_kv_heads = self.num_key_value_heads
     head_dim = self.head_dim
     # assert(n_kv_heads * n_groups == n_heads)
-
-    Xn = hidden_states.view(self.hidden_size)
-    K1, V1 = past_key_value
     seq_len = K1.shape[-2]
-    K1 = K1.view(n_kv_heads, seq_len, head_dim)
-    V1 = V1.view(n_kv_heads, seq_len, head_dim)
+    kv_seq_len = seq_len + 1
+
+    # Prefill phase
+    # if not hasattr(self, "paged_attention"):
+    if do_prefill:
+        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda")
+        self.paged_attention_K = self.paged_attention[:,0]
+        self.paged_attention_V = self.paged_attention[:,1]
+        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
+        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
+        self.temp_QA = torch.empty((2, bsz, 1, hd), dtype = dtype, device = "cuda")
+        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda")
+        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda")
+        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda")
+        self.scalar = 1.0 / math_sqrt(self.head_dim)
+    elif kv_seq_len >= self.paged_attention.shape[0]:
+        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
+        self.paged_attention_K = self.paged_attention[:,0]
+        self.paged_attention_V = self.paged_attention[:,1]
+        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
+    pass
+
+    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
+    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
+    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
+    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
+    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
+    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)

-    # LoRA or general matrix multiplication
-    dtype = Xn.dtype
-    # Qn = self.q_proj(Xn)
-    # Kn = self.k_proj(Xn)
-    # Vn = self.v_proj(Xn)
-    Qn = fast_linear_forward(self.q_proj, Xn)
-    Kn = fast_linear_forward(self.k_proj, Xn)
-    Vn = fast_linear_forward(self.v_proj, Xn)
-
-    # Qn = Qn.view(1, 1, n_heads, head_dim).transpose(1, 2)
-    # Kn = Kn.view(1, 1, n_kv_heads, head_dim).transpose(1, 2)
-    # Vn = Vn.view(1, 1, n_kv_heads, head_dim).transpose(1, 2)
-    Qn = Qn.view(n_heads, 1, head_dim)
-    Kn = Kn.view(n_kv_heads, 1, head_dim)
-    Vn = Vn.view(n_kv_heads, 1, head_dim)
-
-    # kv_seq_len = K1.shape[-2] + 1
     # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
     # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
     cos = self.rotary_emb.cos_cached[seq_len]
     sin = self.rotary_emb.sin_cached[seq_len]
     h = head_dim // 2

-    RH_Q = torch.empty((n_heads, 1, head_dim), dtype = dtype, device = "cuda")
-    RH_Q[:, :, :h] = Qn[:, :, h:]; RH_Q[:, :, h:] = Qn[:, :, :h]; torch.neg(RH_Q[:, :, :h], out = RH_Q[:, :, :h]);
+    RH_Q = self.RH_Q
+    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]; RH_Q[:,:,:,h:] = Qn[:,:,:,:h]; torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h]);
     Qn *= cos; Qn.addcmul_(RH_Q, sin);

-    RH_K = RH_Q[:n_kv_heads, :, :] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda")
-    RH_K[:, :, :h] = Kn[:, :, h:]; RH_K[:, :, h:] = Kn[:, :, :h]; torch.neg(RH_K[:, :, :h], out = RH_K[:, :, :h]);
+    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda")
+    RH_K[:,:,:,:h] = Kn[:,:,:,h:]; RH_K[:,:,:,h:] = Kn[:,:,:,:h]; torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]);
     Kn *= cos; Kn.addcmul_(RH_K, sin);

     # New KV cache
     # Kn = torch.cat([K1, Kn], dim = 2)
     # Vn = torch.cat([V1, Vn], dim = 2)
-    Kn = torch.cat([K1, Kn], dim = 1)
-    Vn = torch.cat([V1, Vn], dim = 1)
-
-    # Grouped query attention
-    if n_groups != 1:
-        # _, _, cached_len, _ = Kn.shape
-        # Knn = Kn[:, :, None, :, :].expand(1, n_kv_heads, n_groups, cached_len, head_dim)
-        # Vnn = Vn[:, :, None, :, :].expand(1, n_kv_heads, n_groups, cached_len, head_dim)
-        # Knn = Knn.reshape(1, n_heads, cached_len, head_dim)
-        # Vnn = Vnn.reshape(1, n_heads, cached_len, head_dim)
-        new_seq_len = seq_len + 1
-        Knn = Kn[:, None, :, :].expand(n_kv_heads, n_groups, new_seq_len, head_dim)
-        Vnn = Vn[:, None, :, :].expand(n_kv_heads, n_groups, new_seq_len, head_dim)
-        Knn = Knn.reshape(n_heads, new_seq_len, head_dim)
-        Vnn = Vnn.reshape(n_heads, new_seq_len, head_dim)
-    else:
-        Knn, Vnn = Kn, Vn
-
-    # Attention
-    # A = torch.matmul(Qn, Knn.transpose(2, 3))
-    A = torch.matmul(Qn, Knn.transpose(1, 2))
-    A *= 1.0 / math_sqrt(self.head_dim)
-    A[:] = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
-    A = torch.matmul(A, Vnn, out = Qn)
-    # A = A.transpose(1, 2)
-    A = A.view(self.hidden_size)
-
-    # A = self.o_proj(A)
-    A = fast_linear_forward(self.o_proj, A)
-    A = A.reshape(1, 1, self.hidden_size)
-
-    # return A, (Kn, Vn)
-    return A, (Kn.unsqueeze(0), Vn.unsqueeze(0))
-pass
-
-
-def LlamaAttention_fast_forward_inference(
-    self,
-    hidden_states: torch.Tensor,
-    past_key_value: Optional[Tuple[torch.Tensor]],
-    position_ids,
-):
-    """
-        https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
-        Fast inference using KV cache.
-        QK^T can be computed in 4 chunks
-
-        [Q, q] @ [K, k].T where q, k are the new tokens.
-        [QK^T, Qk^T]
-        [qK^T, qk^T]
-
-        Since the attention mask wipes Qk^T, we just get
-        [QK^T, 0]
-        [qK^T, qk^T]
-
-        Since softmax is row-wise, we get
-        softmax([QK^T, 0])
-        softmax([qK^T, qk^T])
-
-        We then multiply by [V]
-                            [v]
-        softmax([QK^T, 0])    [softmax(QK^T)V] *
-        softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
-
-        But notice * [softmax(QK^T)V] is just the last attention.
-        We just need to compute the last final row.
-
-        This means we can pass in a row of Q, but we need to
-        remember K and V, which are called the KV cache.
-    """
-    Xn = hidden_states
-    bsz, _, _ = hidden_states.size()
-    K1, V1 = past_key_value
-
-    n_heads = self.num_heads
-    n_groups = self.num_key_value_groups
-    n_kv_heads = self.num_key_value_heads
-    head_dim = self.head_dim
-    assert(n_kv_heads * n_groups == n_heads)
-
-    Qn = self.q_proj(Xn)
-    Kn = self.k_proj(Xn)
-    Vn = self.v_proj(Xn)
-    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
-    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
-    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
-
-    kv_seq_len = K1.shape[-2] + 1
-    cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
-    Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
-
-    # New KV cache
-    Kn = torch.cat([K1, Kn], dim = 2)
-    Vn = torch.cat([V1, Vn], dim = 2)
+    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
+    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
+    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
+    Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)

     # Grouped query attention
     if n_groups != 1:
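# The paged_attention buffer above replaces per-token torch.cat on the KV cache: K and V
# live in one preallocated tensor that is grown by KV_CACHE_INCREMENT slots at a time and
# written into in place. A minimal, runnable sketch of that allocation pattern (the shapes
# and the append_kv helper are illustrative assumptions, not the patch's actual API):
import torch
from typing import Tuple

KV_CACHE_INCREMENT = 128  # grow capacity in chunks instead of reallocating every token

def append_kv(cache: torch.Tensor, new_row: torch.Tensor, length: int) -> Tuple[torch.Tensor, int]:
    # cache: (capacity, bsz, n_kv_heads, head_dim); new_row: (bsz, n_kv_heads, head_dim)
    if length >= cache.shape[0]:
        # Out of room: extend dim 0 by one increment. Because dim 0 is the outermost
        # dimension of a contiguous tensor, resize_ keeps all previously written rows.
        cache.resize_((cache.shape[0] + KV_CACHE_INCREMENT,) + tuple(cache.shape[1:]))
    cache[length] = new_row  # write the new token's K (or V) in place, no torch.cat
    return cache, length + 1

cache, length = torch.empty((KV_CACHE_INCREMENT, 1, 8, 64)), 0
for _ in range(200):  # crosses one capacity boundary
    cache, length = append_kv(cache, torch.randn(1, 8, 64), length)
assert length == 200 and cache.shape[0] == 2 * KV_CACHE_INCREMENT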
@@ -255,28 +182,31 @@ def LlamaAttention_fast_forward_inference(
         Knn, Vnn = Kn, Vn

     # Attention
-    A = torch.matmul(Qn, Knn.transpose(2, 3))
-    A *= 1.0 / (self.head_dim**0.5)
-    A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(A.dtype)
-    A = torch.matmul(A, Vnn)
+    A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:kv_seq_len])
+    A *= self.scalar
+    A[:] = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
+    A = torch.matmul(A, Vnn, out = Qn)
     A = A.transpose(1, 2)
     A = A.reshape(bsz, 1, self.hidden_size)
-    A = original_apply_o(self, A)
+    A = fast_linear_forward(self.o_proj, A, out = self.temp_QA[1])
     return A, (Kn, Vn)
 pass


-torch_silu = torch.nn.functional.silu
 def fast_mlp_inference(self, X):
     # gate = self.gate_proj(X)
     # up   = self.up_proj(X)
-    gate = fast_linear_forward(self.gate_proj, X)
-    up   = fast_linear_forward(self.  up_proj, X)
-    gate = torch_silu(gate, inplace = True)
+    bsz, _, hd = X.shape
+    mlp_size = self.config.intermediate_size
+    temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda")
+
+    gate = fast_linear_forward(self.gate_proj, X, out = temp[0])
+    up   = fast_linear_forward(self.  up_proj, X, out = temp[1])
+    gate = torch.nn.functional.silu(gate, inplace = True)
     gate *= up

     # X = self.down_proj(gate)
-    down = fast_linear_forward(self.down_proj, gate)
+    down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
     return down
 pass
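# fast_mlp_inference above computes the standard Llama MLP, down_proj(silu(gate_proj(x)) * up_proj(x)),
# but writes the gate/up projections into one preallocated temp buffer and applies SiLU in place.
# A plain-PyTorch reference of the same computation (layer sizes here are illustrative, not the
# model's real dimensions):
import torch
import torch.nn.functional as F

hidden, intermediate = 64, 256
gate_proj = torch.nn.Linear(hidden, intermediate, bias = False)
up_proj   = torch.nn.Linear(hidden, intermediate, bias = False)
down_proj = torch.nn.Linear(intermediate, hidden, bias = False)

def llama_mlp(x: torch.Tensor) -> torch.Tensor:
    gate = gate_proj(x)
    up   = up_proj(x)
    gate = F.silu(gate, inplace = True)  # SiLU applied in place, as in the patch
    gate *= up                           # gating reuses the gate buffer itself
    return down_proj(gate)               # project back down to the hidden size

with torch.inference_mode():
    print(llama_mlp(torch.randn(1, 1, hidden)).shape)  # torch.Size([1, 1, 64])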
@@ -307,19 +237,19 @@ def LlamaAttention_fast_forward(
     *args, **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-
-    # Check for inference
-    if past_key_value is not None:
-        A, past_key_value = LlamaAttention_fast_forward_inference(
-            self,
-            hidden_states,
-            past_key_value,
-            position_ids,
-        )
-        return A, None, past_key_value
+    # Clear inference
+    if hasattr(self, "paged_attention"):
+        del self.paged_attention_K
+        del self.paged_attention_V
+        del self.paged_attention
+        del self.temp_QA
+        del self.temp_KV
+        del self.RH_Q
+        del self.attention
     pass
+    bsz, q_len, _ = hidden_states.size()
+
     n_heads = self.num_heads
     n_groups = self.num_key_value_groups
     n_kv_heads = self.num_key_value_heads
@@ -351,7 +281,7 @@ def LlamaAttention_fast_forward(
     past_key_value = (K, V) if use_cache else None

     # Attention module
-    if (not HAS_FLASH_ATTENTION):
+    if (not HAS_FLASH_ATTENTION and attention_mask is None):
         # Xformers memory efficient attention
         # Also has Flash Attention v2 dispatching
         Q = Q.transpose(1, 2)
@@ -373,7 +303,7 @@
         A = xformers_attention(Q, K, V, attn_bias = causal_mask)
         A = A.view(bsz, q_len, n_heads, head_dim)

-    elif HAS_FLASH_ATTENTION:
+    elif HAS_FLASH_ATTENTION and attention_mask is None:
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
         V = V.transpose(1, 2)
@@ -386,11 +316,14 @@
             K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
             V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
         pass
+        # Must be contiguous or else results are False!
+        # https://github.com/pytorch/pytorch/issues/112577
+        Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
         # Needs (batch_size, n_heads, seq_len, head_dim)
         # is_casual and attention_mask must not be both set!
         A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
         # Go back to (batch_size, seq_len, n_heads, head_dim)
-        A = A.transpose(1, 2)
+        A = A.transpose(1, 2).contiguous()
     pass
     attn_output = A.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.apply_o(self, attn_output)
@@ -425,20 +358,18 @@ def LlamaDecoderLayer_fast_forward(
             (see `past_key_values`).
         past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
     """
-    bsz, q_len, hd = hidden_states.size()
     if past_key_value is not None:
+        do_prefill = not hasattr(self.self_attn, "paged_attention")
+
         # Self Attention
         residual = hidden_states
         hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
+        hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
+            self.self_attn,
+            hidden_states,
+            past_key_value,
+            position_ids,
+            do_prefill = do_prefill,
         )
         hidden_states += residual
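# The Q/K/V .contiguous() calls added above guard against a reported SDPA correctness issue
# with non-contiguous (transposed) inputs, per the issue cited in the patch comment
# (https://github.com/pytorch/pytorch/issues/112577). A small sketch of the pattern, with
# illustrative shapes:
import torch
from torch.nn.functional import scaled_dot_product_attention

bsz, n_heads, q_len, head_dim = 1, 8, 16, 64
# transpose() returns non-contiguous views; make them contiguous before SDPA
Q = torch.randn(bsz, q_len, n_heads, head_dim).transpose(1, 2).contiguous()
K = torch.randn(bsz, q_len, n_heads, head_dim).transpose(1, 2).contiguous()
V = torch.randn(bsz, q_len, n_heads, head_dim).transpose(1, 2).contiguous()
mask = torch.zeros(bsz, 1, q_len, q_len)  # additive attention mask (all zeros = nothing masked)
A = scaled_dot_product_attention(Q, K, V, attn_mask = mask, is_causal = False)
A = A.transpose(1, 2).contiguous()        # back to (bsz, q_len, n_heads, head_dim) for the reshape
print(A.shape)                            # torch.Size([1, 16, 8, 64])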
@@ -540,7 +471,7 @@ def LlamaModel_fast_forward(
     pass

     # We already handle KV cache position_ids ourselves.
-    if (past_key_values_length != 0):
+    if False:#(past_key_values_length != 0):
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length,
             dtype = torch.int32,
@@ -576,17 +507,16 @@ def LlamaModel_fast_forward(
     # Ignore attention_mask
     if attention_mask is None:
         padding_mask = None
-    elif False:
+    elif self.training:
         attention_mask = None
         padding_mask = None
     else:
-        if 0 in attention_mask:
-            padding_mask = attention_mask
-        else:
-            padding_mask = None
+        # if 0 in attention_mask:
+        #     padding_mask = attention_mask
+        # else:
+        padding_mask = None

-        from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-        attention_mask = _prepare_4d_causal_attention_mask(
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
             attention_mask,
             (batch_size, seq_length),
             inputs_embeds,
@@ -598,11 +528,12 @@
     hidden_states = inputs_embeds

     if past_key_values is None and self.gradient_checkpointing and self.training:
-        if use_cache:
-            logger.warning_once(
-                "Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`"
-            )
-            use_cache = False
+        use_cache = False
+        # if use_cache:
+        #     logger.warning_once(
+        #         "Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`"
+        #     )
+        #     use_cache = False
     pass

     # decoder layers
@@ -654,13 +585,8 @@ def custom_forward(*inputs):
         if output_attentions:
             all_self_attns += (layer_outputs[1],)
     pass
-
-    bsz, q_len, hd = hidden_states.size()
-    if past_key_values is not None:
-        hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
-    else:
-        hidden_states = fast_rms_layernorm(self.norm, hidden_states)
-    pass
+
+    hidden_states = fast_rms_layernorm(self.norm, hidden_states)

     # add hidden states from the last decoder layer
     if output_hidden_states:
@@ -678,6 +604,50 @@
 pass


+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
+@torch.inference_mode
+def LlamaModel_fast_forward_inference(
+    self,
+    input_ids,
+    past_key_values,
+):
+    # Fix out of bounds tokenization
+    input_ids = input_ids[:,:self.max_seq_length]
+
+    hidden_states = self.embed_tokens(input_ids)
+
+    next_decoder_cache = []
+    for idx, decoder_layer in enumerate(self.layers):
+        # Self Attention
+        residual = hidden_states
+        hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
+        hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
+            decoder_layer.self_attn,
+            hidden_states,
+            past_key_values[idx],
+            None,
+        )
+        hidden_states += residual
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
+        hidden_states = fast_mlp_inference(decoder_layer.mlp, hidden_states)
+        hidden_states += residual
+
+        next_decoder_cache.append(present_key_value)
+    pass
+    hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
+
+    return BaseModelOutputWithPast(
+        last_hidden_state = hidden_states,
+        past_key_values = next_decoder_cache,
+        hidden_states = [],
+        attentions = [],
+    )
+pass
+
+
 def LlamaForCausalLM_fast_forward(
     self,
     input_ids: torch.LongTensor = None,
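# LlamaModel_fast_forward_inference above is the specialised single-token decode path: it
# runs under torch.inference_mode, skips attention-mask construction, and threads each
# layer's (K, V) through the preallocated cache. For context, a decode loop of the kind this
# path serves (the `model` object and its HF-style forward signature are assumptions, not
# code from this patch):
import torch

@torch.inference_mode()
def greedy_decode(model, input_ids: torch.Tensor, max_new_tokens: int = 32) -> torch.Tensor:
    out = model(input_ids = input_ids, use_cache = True)   # prefill over the whole prompt
    next_id = out.logits[:, -1:].argmax(dim = -1)
    generated = [next_id]
    for _ in range(max_new_tokens - 1):
        # Each later step feeds exactly one new token plus the cached K/V: one row of Q.
        out = model(input_ids = next_id, past_key_values = out.past_key_values, use_cache = True)
        next_id = out.logits[:, -1:].argmax(dim = -1)
        generated.append(next_id)
    return torch.cat([input_ids] + generated, dim = -1)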
@@ -694,7 +664,7 @@
     *args, **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
-    if causal_mask is None:
+    if causal_mask is None and past_key_values is None:
         causal_mask = xformers.attn_bias.LowerTriangularMask()

     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -705,18 +675,28 @@
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     self.model._has_no_labels = labels is None
-    outputs = self.model(
-        input_ids=input_ids,
-        causal_mask=causal_mask,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
+
+    if past_key_values is not None and \
+        hasattr(self.model.layers[0].self_attn, "paged_attention"):
+        outputs = LlamaModel_fast_forward_inference(
+            self.model,
+            input_ids,
+            past_key_values,
+        )
+    else:
+        outputs = self.model(
+            input_ids=input_ids,
+            causal_mask=causal_mask,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+    pass
     hidden_states = outputs[0]
     bsz, q_len, hd = hidden_states.shape
@@ -1228,11 +1208,6 @@ def patch_peft_model(
     @staticmethod
     def for_inference(model):
-        if not hasattr(model, "_original_forward"):
-            model._original_forward = model.forward
-        pass
-        model.forward = torch.inference_mode(model._original_forward)
-
         internal_model = model
         internal_model.gradient_checkpointing = False
         internal_model.training = False
@@ -1247,10 +1222,6 @@ def for_inference(model):
     @staticmethod
     def for_training(model, use_gradient_checkpointing = True):
-        if hasattr(model, "_original_forward"):
-            model.forward = model._original_forward
-        pass
-
         internal_model = model
         internal_model.gradient_checkpointing = use_gradient_checkpointing
         internal_model.training = True
diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py
index 56fc5436..c8c73dce 100644
--- a/unsloth/models/mapper.py
+++ b/unsloth/models/mapper.py
@@ -19,26 +19,26 @@
 __INT_TO_FLOAT_MAPPER = \
 {
-    "unsloth/mistral-7b-bnb-4bit" : (
+    "unsloth/mistral-7b-bnb-4bit"  : (
         "unsloth/mistral-7b",
         "mistralai/Mistral-7B-v0.1",
     ),
-    "unsloth/llama-2-7b-bnb-4bit" : (
+    "unsloth/llama-2-7b-bnb-4bit"  : (
         "unsloth/llama-2-7b",
         "meta-llama/Llama-2-7b-hf",
     ),
-    "unsloth/llama-2-13b-bnb-4bit" : (
+    "unsloth/llama-2-13b-bnb-4bit" : (
         "unsloth/llama-13-7b",
         "meta-llama/Llama-2-13b-hf",
     ),
     "unsloth/codellama-34b-bnb-4bit" : (
         "codellama/CodeLlama-34b-hf",
     ),
-    "unsloth/zephyr-sft-bnb-4bit" : (
+    "unsloth/zephyr-sft-bnb-4bit"  : (
         "unsloth/zephyr-sft",
         "HuggingFaceH4/mistral-7b-sft-beta",
     ),
-    "unsloth/tinyllama-bnb-4bit" : (
+    "unsloth/tinyllama-bnb-4bit"   : (
         "unsloth/tinyllama",
         "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
     ),
@@ -48,6 +48,28 @@
     "unsloth/mistral-7b-instruct-v0.2-bnb-4bit" : (
         "mistralai/Mistral-7B-Instruct-v0.2",
     ),
+    "unsloth/llama-2-7b-chat-bnb-4bit" : (
+        "unsloth/llama-2-7b-chat",
+        "meta-llama/Llama-2-7b-chat-hf",
+    ),
+    "unsloth/llama-2-7b-chat-bnb-4bit" : (
+        "unsloth/llama-2-7b-chat",
+        "meta-llama/Llama-2-7b-chat-hf",
+    ),
+    "unsloth/codellama-7b-bnb-4bit" : (
+        "unsloth/codellama-7b",
+        "codellama/CodeLlama-7b-hf",
+    ),
+    "unsloth/codellama-13b-bnb-4bit" : (
+        "codellama/CodeLlama-13b-hf",
+    ),
+    "unsloth/yi-6b-bnb-4bit" : (
+        "unsloth/yi-6b",
+        "01-ai/Yi-6B",
+    ),
+    "unsloth/solar-10.7b-bnb-4bit" : (
+        "upstage/SOLAR-10.7B-v1.0",
+    ),
 }

 INT_TO_FLOAT_MAPPER = {}
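# The mapper above pairs each pre-quantized 4-bit repo with its float16 source repo(s),
# presumably so a full-precision model name can be swapped for its 4-bit counterpart and
# back. A hedged sketch of how such a table is typically flattened into forward and reverse
# lookups -- FLOAT_TO_INT_MAPPER below is an assumed name, not taken from the patch:
__INT_TO_FLOAT_MAPPER = {
    "unsloth/llama-2-7b-chat-bnb-4bit": ("unsloth/llama-2-7b-chat", "meta-llama/Llama-2-7b-chat-hf"),
}
INT_TO_FLOAT_MAPPER = {}
FLOAT_TO_INT_MAPPER = {}
for int_name, float_names in __INT_TO_FLOAT_MAPPER.items():
    INT_TO_FLOAT_MAPPER[int_name] = float_names[0]   # 4-bit repo -> canonical float repo
    for float_name in float_names:
        FLOAT_TO_INT_MAPPER[float_name] = int_name   # any float repo -> 4-bit repo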
diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py
index 42f26b92..bc00e7a9 100644
--- a/unsloth/models/mistral.py
+++ b/unsloth/models/mistral.py
@@ -46,19 +46,19 @@ def MistralAttention_fast_forward(
     *args, **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-
-    # Check for inference
-    if past_key_value is not None:
-        A, past_key_value = LlamaAttention_fast_forward_inference(
-            self,
-            hidden_states,
-            past_key_value,
-            position_ids,
-        )
-        return A, None, past_key_value
+    # Clear inference
+    if hasattr(self, "paged_attention"):
+        del self.paged_attention_K
+        del self.paged_attention_V
+        del self.paged_attention
+        del self.temp_QA
+        del self.temp_KV
+        del self.RH_Q
+        del self.attention
     pass
+    bsz, q_len, _ = hidden_states.size()
+
     n_heads = self.num_heads
     n_groups = self.num_key_value_groups
     n_kv_heads = self.num_key_value_heads
@@ -90,7 +90,7 @@ def MistralAttention_fast_forward(
     past_key_value = (K, V) if use_cache else None

     # Attention module
-    if (not HAS_FLASH_ATTENTION):
+    if (not HAS_FLASH_ATTENTION and attention_mask is None):
         # Xformers memory efficient attention
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
@@ -128,7 +128,7 @@
         A = xformers_attention(Q, K, V, attn_bias = causal_mask)
         A = A.view(bsz, q_len, n_heads, head_dim)

-    elif HAS_FLASH_ATTENTION:
+    elif HAS_FLASH_ATTENTION and attention_mask is None:
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
         V = V.transpose(1, 2)
@@ -144,11 +144,14 @@
             K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
             V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
         # pass
+        # Must be contiguous or else results are False!
+        # https://github.com/pytorch/pytorch/issues/112577
+        Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
         # Needs (batch_size, n_heads, seq_len, head_dim)
         # is_casual and attention_mask must not be both set!
         A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
         # Go back to (batch_size, seq_len, n_heads, head_dim)
-        A = A.transpose(1, 2)
+        A = A.transpose(1, 2).contiguous()
     pass
     attn_output = A.reshape(bsz, q_len, self.hidden_size)
@@ -174,7 +177,7 @@ def MistralForCausalLM_fast_forward(
     *args, **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
-    if causal_mask is None:
+    if causal_mask is None and past_key_values is None:
         bsz, q_len = input_ids.shape
         sliding_window = getattr(self.config, "sliding_window", None)
         if sliding_window is None or sliding_window == "null" or sliding_window <= 0:
@@ -196,18 +199,28 @@
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     self.model._has_no_labels = labels is None
-    outputs = self.model(
-        input_ids=input_ids,
-        causal_mask=causal_mask,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
+
+    if past_key_values is not None and \
+        hasattr(self.model.layers[0].self_attn, "paged_attention"):
+        outputs = LlamaModel_fast_forward_inference(
+            self.model,
+            input_ids,
+            past_key_values,
+        )
+    else:
+        outputs = self.model(
+            input_ids=input_ids,
+            causal_mask=causal_mask,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+    pass
     hidden_states = outputs[0]
     bsz, q_len, hd = hidden_states.shape
diff --git a/unsloth/save.py b/unsloth/save.py
index 4cda8b96..ae0f97a8 100644
--- a/unsloth/save.py
+++ b/unsloth/save.py
@@ -744,7 +744,6 @@ def unsloth_push_to_hub_merged(
     [](https://github.com/unslothai/unsloth)
     """

-
 def upload_to_huggingface(model, save_directory, token, method, extra = "", file_location = None):
     # Check for username
     username = ""
@@ -797,6 +796,19 @@ def upload_to_huggingface(model, save_directory, token, method, extra = "", file
             repo_id = save_directory,
             repo_type = "model",
         )
+
+        # We also upload a config.json file
+        import json
+        with open("_temporary_unsloth_config.json", "w") as file:
+            json.dump({"model_type" : model.config.model_type}, file, indent = 4)
+        pass
+        hf_api.upload_file(
+            path_or_fileobj = "_temporary_unsloth_config.json",
+            path_in_repo = "config.json",
+            repo_id = save_directory,
+            repo_type = "model",
+        )
+        os.remove("_temporary_unsloth_config.json")
     pass
     return username
 pass