Commit 8b56823: qwen2 + starcoder2

Cyrilvallez committed Dec 13, 2024
1 parent c4679f2 commit 8b56823
Showing 25 changed files with 1,469 additions and 845 deletions.
5 changes: 4 additions & 1 deletion examples/modular-transformers/modeling_dummy.py
@@ -501,7 +501,10 @@ def forward(
             past_key_values = DynamicCache()

         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
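Note: the `cache_position` change above (repeated in several models later in this commit) starts the new positions at the number of tokens already held in the KV cache instead of always at 0. A minimal sketch of the effect, not part of the commit, using a plain integer in place of `past_key_values.get_seq_length()` and arbitrary illustrative sizes:

import torch

# Illustrative values: 5 tokens already cached, 3 new embeddings in this step.
past_seen_tokens = 5      # what past_key_values.get_seq_length() would return
new_seq_len = 3           # inputs_embeds.shape[1]

# Previous behaviour: positions always started at 0, which is only right for prefill.
old_cache_position = torch.arange(0, new_seq_len)                                    # tensor([0, 1, 2])

# New behaviour: positions continue from the end of the cache.
new_cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_seq_len)  # tensor([5, 6, 7])
print(new_cache_position)
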
5 changes: 4 additions & 1 deletion examples/modular-transformers/modeling_multimodal1.py
@@ -501,7 +501,10 @@ def forward(
             past_key_values = DynamicCache()

         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
5 changes: 5 additions & 0 deletions src/transformers/integrations/flash_attention.py
@@ -13,6 +13,11 @@ def flash_attention_forward(
     else:
         seq_len = query.shape[1]

+    # FA2 uses non-transposed inputs
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+
     dropout_rate = config.attention_dropout if training else 0.0

     input_dtype = query.dtype
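Note: the new transposes reflect that the flash-attention kernels take tensors laid out as (batch, seq_len, num_heads, head_dim), while the attention modules hand over (batch, num_heads, seq_len, head_dim). A small shape check, separate from the commit, with arbitrarily chosen sizes:

import torch

batch, num_heads, seq_len, head_dim = 2, 8, 16, 64

# Layout produced by the attention modules.
query = torch.randn(batch, num_heads, seq_len, head_dim)

# Layout the FA2 kernels expect, obtained by the transpose added above.
query_fa2 = query.transpose(1, 2)
assert query_fa2.shape == (batch, seq_len, num_heads, head_dim)
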
4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
@@ -1490,8 +1490,8 @@ def _autoset_attn_implementation(
                 message += (
                     ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
                 )
-        if cls.model_type + "_"+ config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
-            config._attn_implementation_internal = cls.model_type+ "_" + config._attn_implementation
+        if cls.model_type + "_" + config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
+            config._attn_implementation_internal = cls.model_type + "_" + config._attn_implementation
         if config._attn_implementation in ALL_ATTENTION_FUNCTIONS:
             pass
         else:
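Note: besides the spacing fix, these lines let a model register an attention backend under a model-specific key (for example the "gemma_sdpa" entry added further down in this commit) that takes priority over the generic one. A rough, standalone sketch of that resolution order, with a plain dict standing in for ALL_ATTENTION_FUNCTIONS and illustrative values:

# Stand-in registry; the real ALL_ATTENTION_FUNCTIONS maps names to attention callables.
ATTENTION_REGISTRY = {"sdpa": "generic sdpa forward", "gemma_sdpa": "gemma-specific sdpa forward"}

def resolve_attn_implementation(model_type: str, requested: str) -> str:
    """Prefer a model-prefixed implementation, then fall back to the generic one."""
    prefixed = model_type + "_" + requested
    if prefixed in ATTENTION_REGISTRY:
        return prefixed
    if requested in ATTENTION_REGISTRY:
        return requested
    raise ValueError(f"Unknown attention implementation: {requested}")

print(resolve_attn_implementation("gemma", "sdpa"))   # -> "gemma_sdpa"
print(resolve_attn_implementation("llama", "sdpa"))   # -> "sdpa"
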
5 changes: 4 additions & 1 deletion src/transformers/models/aria/modeling_aria.py
@@ -917,7 +917,10 @@ def forward(
             past_key_values = DynamicCache()

         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
2 changes: 1 addition & 1 deletion src/transformers/models/gemma/modular_gemma.py
@@ -347,6 +347,7 @@ def extra_repr(self):

 ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)

+
 class GemmaDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: GemmaConfig, layer_idx: int):
         super().__init__(config, layer_idx)
@@ -355,7 +356,6 @@ def __init__(self, config: GemmaConfig, layer_idx: int):


 class GemmaModel(LlamaModel):
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
6 changes: 5 additions & 1 deletion src/transformers/models/gemma2/modular_gemma2.py
@@ -214,7 +214,8 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
     pass


-def gemma_eager_attention_forward( config: Gemma2Config,
+def gemma_eager_attention_forward(
+    config: Gemma2Config,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -344,9 +345,11 @@ def gemma_sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
     "gemma_sdpa": gemma_sdpa_attention_forward,
 }

+
 class Gemma2Attention(GemmaAttention):
     pass

+
 class Gemma2DecoderLayer(GemmaDecoderLayer):
     def __init__(self, config: Gemma2Config, layer_idx: int):
         super().__init__()
@@ -417,6 +420,7 @@ def forward(
 class Gemma2PreTrainedModel(GemmaPreTrainedModel):
     pass

+
 class Gemma2Model(GemmaModel, Gemma2PreTrainedModel):
     def __init__(self, config: Gemma2Config):
         super().__init__(config)
5 changes: 4 additions & 1 deletion src/transformers/models/glm/modeling_glm.py
@@ -516,7 +516,10 @@ def forward(
             past_key_values = DynamicCache()

         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
2 changes: 0 additions & 2 deletions src/transformers/models/glm/modular_glm.py
@@ -115,8 +115,6 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
         self.scaling = 1 / math.sqrt(self.head_dim)


-
-
 class GlmDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
         super().__init__()
19 changes: 7 additions & 12 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -157,6 +157,7 @@ def eager_attention_forward(self, query, key, value, attention_mask=None, head_m

     return attn_output, attn_weights

+
 class GPT2Attention(nn.Module):
     def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()
@@ -217,8 +218,6 @@ def prune_heads(self, heads):
         self.num_heads = self.num_heads - len(heads)
         self.pruned_heads = self.pruned_heads.union(heads)

-
-
     def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
         bsz, num_heads, q_seq_len, dk = query.size()
@@ -296,8 +295,6 @@ def forward(
         else:
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

-
-
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)

@@ -316,17 +313,16 @@ def forward(
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

-
         if self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
         else:
             attn_output, attn_weights = attention_interface(
-                    self,
-                    query_states,
-                    key_states,
-                    value_states,
-                    **kwargs,
-                )
+                self,
+                query_states,
+                key_states,
+                value_states,
+                **kwargs,
+            )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.c_proj(attn_output)
@@ -351,7 +347,6 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
         return hidden_states


-
 class GPT2Block(nn.Module):
     def __init__(self, config, layer_idx=None):
         super().__init__()
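Note: the `attention_interface` lookup in GPT2Attention.forward follows the registry pattern above: use the in-file eager implementation by default and swap in a registered backend when the config asks for something else. A self-contained sketch of that dispatch, with a stand-in registry and a simplified signature rather than the exact library API:

from typing import Callable, Dict

import torch

def eager_attention_forward(query, key, value, **kwargs):
    # Plain softmax(QK^T / sqrt(d))V attention as a stand-in for the real implementation.
    attn_weights = torch.softmax(query @ key.transpose(-2, -1) / query.shape[-1] ** 0.5, dim=-1)
    return attn_weights @ value, attn_weights

# Stand-in for ALL_ATTENTION_FUNCTIONS.
ATTENTION_REGISTRY: Dict[str, Callable] = {"eager": eager_attention_forward}

def run_attention(attn_implementation: str, query, key, value, **kwargs):
    attention_interface = eager_attention_forward
    if attn_implementation != "eager":
        attention_interface = ATTENTION_REGISTRY[attn_implementation]
    return attention_interface(query, key, value, **kwargs)

q = k = v = torch.randn(2, 8, 16, 64)
attn_output, attn_weights = run_attention("eager", q, k, v)
print(attn_output.shape)   # torch.Size([2, 8, 16, 64])
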
5 changes: 4 additions & 1 deletion src/transformers/models/granite/modeling_granite.py
@@ -548,7 +548,10 @@ def forward(
             past_key_values = DynamicCache()

         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
2 changes: 2 additions & 0 deletions src/transformers/models/granite/modular_granite.py
@@ -29,9 +29,11 @@

 logger = logging.get_logger(__name__)

+
 class GraniteRMSNorm(LlamaRMSNorm):
     pass

+
 class GraniteMLP(LlamaMLP):
     pass

5 changes: 4 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -536,7 +536,10 @@ def forward(
             past_key_values = DynamicCache()

         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
40 changes: 22 additions & 18 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -28,10 +28,27 @@

 logger = logging.get_logger(__name__)

-_CHECKPOINT_FOR_DOC = "meta-mistral/Mistral-2-7b-hf"
+
+_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
 _CONFIG_FOR_DOC = "MistralConfig"


+class MistralMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
 class MistralRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -117,22 +134,6 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-class MistralMLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
-        self.act_fn = ACT2FN[config.hidden_act]
-
-    def forward(self, x):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
-
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -509,7 +510,10 @@ def forward(
             past_key_values = DynamicCache()

         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
14 changes: 13 additions & 1 deletion src/transformers/models/mistral/modular_mistral.py
@@ -1,17 +1,30 @@
 import torch
 import torch.utils.checkpoint
+from torch import nn

 from ...cache_utils import Cache, SlidingWindowCache, StaticCache
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ..llama.modeling_llama import (
     LlamaForQuestionAnswering,
     LlamaForSequenceClassification,
     LlamaForTokenClassification,
+    LlamaMLP,
     LlamaModel,
 )
 from .configuration_mistral import MistralConfig


+_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
+
+
+class MistralMLP(LlamaMLP):
+    def __init__(self, config):
+        super().__init__(config)
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+
 class MistralModel(LlamaModel):
     def _update_causal_mask(
         self,
@@ -175,4 +188,3 @@ class MistralForSequenceClassification(LlamaForSequenceClassification):

 class MistralForQuestionAnswering(LlamaForQuestionAnswering):
     pass
-
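Note: the modular MistralMLP above pins the three projections to `bias=False`, whereas the inherited LlamaMLP reads `config.mlp_bias` (a field MistralConfig presumably does not define); the generated modeling_mistral.py earlier in this diff reflects the same change. A quick illustration, outside the commit, of what the override amounts to:

import torch.nn as nn

hidden_size, intermediate_size = 32, 64   # illustrative sizes, not Mistral's real dimensions

# The overridden projections are plain Linear layers without bias terms.
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

assert gate_proj.bias is None and up_proj.bias is None and down_proj.bias is None
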
8 changes: 2 additions & 6 deletions src/transformers/models/mixtral/modular_mixtral.py
@@ -36,7 +36,7 @@
     logging,
 )
 from ..llama.modeling_llama import LlamaForCausalLM, LlamaRMSNorm
-from ..mistral.modeling_mistral import MistralAttention, MistralModel, MistralPreTrainedModel
+from ..mistral.modeling_mistral import MistralAttention, MistralModel
 from .configuration_mixtral import MixtralConfig


@@ -217,6 +217,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class MixtralRMSNorm(LlamaRMSNorm):
     pass

+
 class MixtralAttention(MistralAttention):
     pass

@@ -302,10 +303,7 @@ def forward(
         return outputs


-
-
 class MixtralModel(MistralModel):
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -442,7 +440,6 @@ def forward(
         )


-
 class MixtralForCausalLM(LlamaForCausalLM):
     _tied_weights_keys = ["lm_head.weight"]

@@ -457,7 +454,6 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()

-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
5 changes: 4 additions & 1 deletion src/transformers/models/olmo/modeling_olmo.py
@@ -541,7 +541,10 @@ def forward(
             past_key_values = DynamicCache()

         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
3 changes: 1 addition & 2 deletions src/transformers/models/olmo/modular_olmo.py
@@ -104,6 +104,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int):
         self.input_layernorm = OlmoLayerNorm(config.hidden_size)
         self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)

+
 class OlmoForCausalLM(LlamaForCausalLM):
     pass

@@ -118,5 +119,3 @@ class OlmoForSequenceClassification(LlamaForSequenceClassification):

 class OlmoForQuestionAnswering(LlamaForQuestionAnswering):
     pass
-
-