Several fixes related to rotary position embeddings
First part of resolution of huggingface#35233
- Changes related to `position_embeddings` being a mandatory argument
- Remove `position_ids` argument of `apply_rotary_pos_emb`
- Replace `torch.stack` with `torch.cat`; the former requires tensors of equal shape
- `esm`: RoPE depends on `position_ids`, which was previously ignored.
- `gpt_neox`: Removed selection of the attention compute type via class.
- `gptj`: RoPE must be applied per head; also fixes some shape issues.
- `nemotron`: `config.partial_rotary_factor` was not implemented.
mseeger committed Dec 20, 2024
1 parent 608e163 commit d8f70ce
Showing 97 changed files with 1,725 additions and 428 deletions.
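For orientation, here is a minimal sketch (not code from this commit) of the Llama-style half-split RoPE convention that most of the touched models share, with the new `apply_rotary_pos_emb` signature that drops `position_ids`. `torch.cat` concatenates along an existing dimension and still works when the two pieces differ in size, whereas `torch.stack` requires identical shapes.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate the half-split channels: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # q, k: [batch, heads, seq, head_dim]; cos, sin: [batch, seq, head_dim].
    # Unsqueeze so cos/sin broadcast over the head dimension.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```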
5 changes: 3 additions & 2 deletions src/transformers/modeling_rope_utils.py
@@ -155,7 +155,8 @@ def _compute_dynamic_ntk_parameters(
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings

# Compute the inverse frequencies
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
if dim != 2:
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor

@@ -230,7 +231,7 @@ def linear_ramp_factor(min, max, dim):
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)

# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, pos_freqs.shape[0]).float().to(device)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
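Read together, the guarded hunk above computes the inverse frequencies roughly as in this sketch (the real `_compute_dynamic_ntk_parameters` also returns an attention factor and pulls its arguments from the config); the `dim != 2` guard avoids a division by zero in the exponent `dim / (dim - 2)`.

```python
import torch

def dynamic_ntk_inv_freq(dim, base, factor, seq_len, max_position_embeddings, device=None):
    # Never scale below the training context length.
    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
    if dim != 2:
        # Rescale the base so the longest wavelength grows with seq_len.
        base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
```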
17 changes: 10 additions & 7 deletions src/transformers/models/aria/modeling_aria.py
@@ -437,16 +437,14 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -537,6 +535,8 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if position_embeddings is None:
raise ValueError("position_embeddings = (cos, sin) must be given")
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

@@ -603,13 +603,13 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
@@ -619,13 +619,13 @@ def forward(
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -777,6 +777,9 @@ def forward(self, x, position_ids):
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
# Happens if self.dim is odd
if emb.shape[-1] > self.dim:
emb = emb[..., :self.dim]
cos = emb.cos()
sin = emb.sin()

@@ -963,24 +966,24 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
position_embeddings,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

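The reordering in `modeling_aria.py` reflects the calling pattern the commit enforces: the model computes `(cos, sin)` once per forward pass and every decoder layer must receive it, with attention now raising if it is missing. A simplified, hypothetical sketch of that flow (not the actual `AriaTextModel.forward`):

```python
import torch

def run_decoder_stack(model, hidden_states, position_ids, causal_mask=None):
    # One (cos, sin) pair for the whole stack, derived from position_ids.
    position_embeddings = model.rotary_emb(hidden_states, position_ids)
    for decoder_layer in model.layers:
        layer_outputs = decoder_layer(
            hidden_states,
            position_embeddings=position_embeddings,  # now effectively mandatory
            attention_mask=causal_mask,
            position_ids=position_ids,
        )
        hidden_states = layer_outputs[0]
    return hidden_states
```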
2 changes: 2 additions & 0 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -305,6 +305,8 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if position_embeddings is None:
raise ValueError("position_embeddings = (cos, sin) must be given")
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

7 changes: 4 additions & 3 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -110,6 +110,9 @@ def forward(self, x, position_ids):
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
# Happens if self.dim is odd
if emb.shape[-1] > self.dim:
emb = emb[..., :self.dim]
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@@ -153,16 +156,14 @@ def rotate_half(x):


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
13 changes: 8 additions & 5 deletions src/transformers/models/clvp/modeling_clvp.py
@@ -251,8 +251,8 @@ class ClvpRotaryPositionalEmbedding(nn.Module):

def __init__(self, config):
super().__init__()
dim = max(config.projection_dim // (config.num_attention_heads * 2), 32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
self.dim = max(config.projection_dim // (config.num_attention_heads * 2), 32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))

self.register_buffer("inv_freq", inv_freq)
self.cached_sequence_length = None
@@ -267,9 +267,12 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
self.cached_sequence_length = sequence_length
time_stamps = torch.arange(sequence_length, device=hidden_states.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
embeddings = torch.cat((freqs, freqs), dim=-1)
emb = torch.cat((freqs, freqs), dim=-1)
# Happens if self.dim is odd
if emb.shape[-1] > self.dim:
emb = emb[..., :self.dim]

self.cached_rotary_positional_embedding = embeddings.unsqueeze(0)
self.cached_rotary_positional_embedding = emb.unsqueeze(0)
return self.cached_rotary_positional_embedding


@@ -313,7 +316,7 @@ def forward(
rotary_pos_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: Optional[bool] = False,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
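Assembled from the CLVP hunks above: the rotary angle table is rebuilt only when the sequence length changes, and the concatenated frequencies are clipped back to `dim` so an odd rotary dimension does not overshoot by one. A self-contained sketch of that pattern (the class name here is illustrative, not the library's):

```python
import torch
from torch import nn

class CachedRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        sequence_length = hidden_states.shape[1]
        # Reuse the cached table if the sequence length has not changed.
        if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
            return self.cached_rotary_positional_embedding
        self.cached_sequence_length = sequence_length
        time_stamps = torch.arange(sequence_length, device=hidden_states.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        if emb.shape[-1] > self.dim:  # happens when dim is odd
            emb = emb[..., : self.dim]
        self.cached_rotary_positional_embedding = emb.unsqueeze(0)
        return self.cached_rotary_positional_embedding
```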
8 changes: 6 additions & 2 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -48,14 +48,18 @@ def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
return torch.concat((-x2, x1), dim=-1)


# Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
emb_size = tensor.shape[-1]
# Happens if emb_size is odd
if cos.shape[-1] > emb_size:
cos = cos[..., :emb_size]
sin = sin[..., :emb_size]
return (tensor * cos) + (rotate_every_two(tensor) * sin)


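Several files add the same guard: sin/cos tables are built from `dim // 2` frequencies and therefore always have an even last dimension, so they must be clipped back when the rotary dimension is odd. A tiny sketch with a hypothetical helper name:

```python
import torch

def clip_to_rotary_dim(cos: torch.Tensor, sin: torch.Tensor, rotary_dim: int):
    # cos/sin have last dimension 2 * ceil(rotary_dim / 2); drop the extra
    # column so they broadcast against tensors of width rotary_dim.
    if cos.shape[-1] > rotary_dim:
        cos = cos[..., :rotary_dim]
        sin = sin[..., :rotary_dim]
    return cos, sin
```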
30 changes: 20 additions & 10 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -114,6 +114,7 @@ def __init__(
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
self.rotary_ndims = None
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
@@ -122,6 +123,7 @@
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.rotary_ndims = config.hidden_size // config.num_attention_heads

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
@@ -162,6 +164,9 @@ def forward(self, x, position_ids):
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.repeat_interleave(freqs, 2, dim=-1) # This line differs from Llama's implementation
# Happens if self.rotary_ndims is odd
if self.config is not None and emb.shape[-1] > self.rotary_ndims:
emb = emb[..., :self.rotary_ndims]
cos = emb.cos()
sin = emb.sin()

@@ -176,8 +181,7 @@ def rotate_half(x):
# Split and rotate. Note that this function is different from e.g. Llama.
x1 = x[..., ::2]
x2 = x[..., 1::2]
rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
return rot_x
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -286,15 +290,17 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if position_embeddings is None:
raise ValueError("position_embeddings = (cos, sin) must be given")
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
@@ -372,15 +378,17 @@ def __init__(self, *args, **kwargs):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if position_embeddings is None:
raise ValueError("position_embeddings = (cos, sin) must be given")
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
@@ -476,14 +484,16 @@ class CohereSdpaAttention(CohereAttention):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if position_embeddings is None:
raise ValueError("position_embeddings = (cos, sin) must be given")
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
@@ -492,13 +502,13 @@ def forward(
)
return super().forward(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)

bsz, q_len, _ = hidden_states.size()
@@ -581,13 +591,13 @@ def __init__(self, config: CohereConfig, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -615,13 +625,13 @@ def forward(
# Self Attention
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)

# Fully Connected
@@ -861,24 +871,24 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
position_embeddings,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

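The new `rotary_ndims` attribute (and the `config.partial_rotary_factor` fix mentioned for `nemotron` in the commit message) boils down to rotating only the first `rotary_ndims` channels of each head and passing the rest through unchanged. A hedged sketch, reusing the half-split convention from the sketch near the top rather than Cohere's interleaved variant:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_partial_rotary(q, k, cos, sin, rotary_ndims):
    # q, k: [batch, heads, seq, head_dim]; cos, sin broadcastable to
    # [batch, 1, seq, rotary_ndims]. Only the leading channels are rotated.
    q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]
    k_rot, k_pass = k[..., :rotary_ndims], k[..., rotary_ndims:]
    q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)
```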
3 changes: 1 addition & 2 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -161,8 +161,7 @@ def rotate_half(x):
# Split and rotate. Note that this function is different from e.g. Llama.
x1 = x[..., ::2]
x2 = x[..., 1::2]
rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
return rot_x
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):