fix sliding window models
Cyrilvallez committed Dec 16, 2024
1 parent e0d10f6 commit b275fdc
Showing 23 changed files with 539 additions and 232 deletions.
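Every modeling file below carries the same two-part change: `eager_attention_forward` stops reading dropout and scaling off the module's config and instead takes them as explicit `dropout` and `scaling` keyword arguments (falling back to `head_dim**-0.5` when `scaling` is None), and each attention class gains an `attention_dropout` attribute so its forward can choose the probability at call time. Below is a minimal, self-contained sketch of the new helper, copied from the hunks that follow; the `repeat_kv` body is reconstructed from its usual transformers definition, and the toy `SimpleNamespace` module and tensor shapes are illustrative assumptions, not part of this commit.

from types import SimpleNamespace
from typing import Optional

import torch
from torch import nn


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand each key/value head n_rep times so grouped-query attention can
    # match the number of query heads.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs,
):
    # Dropout and scaling are explicit arguments now; the module is only consulted
    # for defaults (head_dim, num_key_value_groups, training), not for its config.
    if scaling is None:
        scaling = module.head_dim**-0.5

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


# Illustrative call: batch=1, 4 query heads, 2 KV heads, seq=8, head_dim=16.
# A SimpleNamespace stands in for the attention module (duck-typed, not an nn.Module).
module = SimpleNamespace(num_key_value_groups=2, head_dim=16, training=False)
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 2, 8, 16)
v = torch.randn(1, 2, 8, 16)
out, weights = eager_attention_forward(module, q, k, v, dropout=0.1, scaling=16**-0.5)
print(out.shape, weights.shape)  # torch.Size([1, 8, 4, 16]) torch.Size([1, 4, 8, 8])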
29 changes: 22 additions & 7 deletions examples/modular-transformers/modeling_dummy.py
@@ -171,20 +171,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs):
-    config = attention_class.config
-    key_states = repeat_kv(key, attention_class.num_key_value_groups)
-    value_states = repeat_kv(value, attention_class.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = module.head_dim**-0.5
+
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

     return attn_output, attn_weights


@@ -198,6 +210,7 @@ def __init__(self, config: DummyConfig, layer_idx: int):
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
         self.is_causal = True

         self.q_proj = nn.Linear(
@@ -251,6 +264,8 @@ def forward(
             query_states,
             key_states,
             value_states,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
             **kwargs,
         )
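The forward() hunk just above is the consumer side of the refactor: the layer decides its own dropout probability (zero outside of training) and its own scaling, and passes both to the helper. Here is a trimmed, runnable sketch of an attention layer wired the same way, reusing `repeat_kv` and `eager_attention_forward` from the sketch near the top of this page; the class name, sizes, and default values are illustrative, not taken from this diff.

class TinyAttention(nn.Module):
    # Illustrative stand-in for the generated attention classes touched by this commit.
    def __init__(self, hidden_size=64, num_heads=4, num_kv_heads=2, attention_dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden_size // num_heads
        self.num_key_value_groups = num_heads // num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = attention_dropout  # the attribute this commit adds to each class
        self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        bsz, seq_len, _ = hidden_states.shape
        query_states = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        attn_output, attn_weights = eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,  # disabled at eval time
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(bsz, seq_len, -1)
        return self.o_proj(attn_output), attn_weights


layer = TinyAttention().eval()
out, _ = layer(torch.randn(2, 8, 64))
print(out.shape)  # torch.Size([2, 8, 64])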

29 changes: 22 additions & 7 deletions examples/modular-transformers/modeling_multimodal1.py
@@ -171,20 +171,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs):
-    config = attention_class.config
-    key_states = repeat_kv(key, attention_class.num_key_value_groups)
-    value_states = repeat_kv(value, attention_class.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = module.head_dim**-0.5
+
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

     return attn_output, attn_weights


@@ -198,6 +210,7 @@ def __init__(self, config: Multimodal1TextConfig, layer_idx: int):
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
         self.is_causal = True

         self.q_proj = nn.Linear(
@@ -251,6 +264,8 @@ def forward(
             query_states,
             key_states,
             value_states,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
             **kwargs,
         )

29 changes: 22 additions & 7 deletions examples/modular-transformers/modeling_my_new_model2.py
@@ -171,20 +171,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs):
-    config = attention_class.config
-    key_states = repeat_kv(key, attention_class.num_key_value_groups)
-    value_states = repeat_kv(value, attention_class.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = module.head_dim**-0.5
+
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

     return attn_output, attn_weights


@@ -198,6 +210,7 @@ def __init__(self, config: MyNewModel2Config, layer_idx: int):
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
         self.is_causal = True

         self.q_proj = nn.Linear(
@@ -251,6 +264,8 @@ def forward(
             query_states,
             key_states,
             value_states,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
             **kwargs,
         )

29 changes: 22 additions & 7 deletions examples/modular-transformers/modeling_super.py
@@ -171,20 +171,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs):
-    config = attention_class.config
-    key_states = repeat_kv(key, attention_class.num_key_value_groups)
-    value_states = repeat_kv(value, attention_class.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = module.head_dim**-0.5
+
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

     return attn_output, attn_weights


@@ -198,6 +210,7 @@ def __init__(self, config: SuperConfig, layer_idx: int):
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
         self.is_causal = True

         self.q_proj = nn.Linear(
@@ -251,6 +264,8 @@ def forward(
             query_states,
             key_states,
             value_states,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
             **kwargs,
         )

29 changes: 22 additions & 7 deletions src/transformers/models/aria/modeling_aria.py
@@ -476,20 +476,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs):
-    config = attention_class.config
-    key_states = repeat_kv(key, attention_class.num_key_value_groups)
-    value_states = repeat_kv(value, attention_class.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = module.head_dim**-0.5
+
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

     return attn_output, attn_weights


@@ -503,6 +515,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
         self.is_causal = True

         self.q_proj = nn.Linear(
@@ -556,6 +569,8 @@ def forward(
             query_states,
             key_states,
             value_states,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
             **kwargs,
         )

29 changes: 22 additions & 7 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -202,20 +202,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs):
-    config = attention_class.config
-    key_states = repeat_kv(key, attention_class.num_key_value_groups)
-    value_states = repeat_kv(value, attention_class.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = module.head_dim**-0.5
+
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()

     return attn_output, attn_weights


@@ -229,6 +241,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int):
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
         self.is_causal = True

         self.q_proj = nn.Linear(
@@ -282,6 +295,8 @@ def forward(
             query_states,
             key_states,
             value_states,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
             **kwargs,
        )

1 change: 1 addition & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -174,6 +174,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = config.query_pre_attn_scalar**-0.5
+        self.attention_dropout = config.attention_dropout
         self.is_causal = True

         self.q_proj = nn.Linear(
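Gemma2 is the one model on this page whose scaling is not `head_dim**-0.5`: the hunk above sets `self.scaling = config.query_pre_attn_scalar**-0.5`. That is exactly the case the explicit `scaling` argument exists for, since the module passes its own value and the helper's `head_dim`-based fallback is never used. A tiny illustration with hypothetical numbers (both values are made up for the example):

# Hypothetical numbers showing why an explicit `scaling` argument matters for
# Gemma2-style models: the head_dim-based default can differ from the configured value.
head_dim = 256
query_pre_attn_scalar = 224  # illustrative; read from Gemma2Config in the real code

default_scaling = head_dim**-0.5              # what eager_attention_forward falls back to
gemma2_scaling = query_pre_attn_scalar**-0.5  # what the module passes via scaling=...
print(default_scaling, gemma2_scaling)        # 0.0625 vs ~0.0668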