Skip to content

Commit

Permalink
Several fixes related to rotary position embeddings
Browse files Browse the repository at this point in the history
First part of resolution of #35233
- Changes related to `position_embeddings` being a mandatory argument
- Remove `position_ids` argument of `apply_rotary_pos_emb`
- Replace `torch.stack` by `torch.cat`, former requires equal shapes
- `esm`: RoPE depends on `position_ids`, which was ignored.
- `gpt_neox`: Selection of attention compute type via class removed
- `gptj`, `codegen`: RoPE must be applied per head, and some shape issues.
- `nemotron`: `config.partial_rotary_factor` was not implemented.
  • Loading branch information
mseeger committed Dec 29, 2024
1 parent 5c75087 commit 2901c93
Show file tree
Hide file tree
Showing 65 changed files with 645 additions and 512 deletions.
14 changes: 7 additions & 7 deletions src/transformers/models/aria/modeling_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,16 +437,14 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
Expand Down Expand Up @@ -537,6 +535,8 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if position_embeddings is None:
raise ValueError("position_embeddings = (cos, sin) must be given")
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

Expand Down Expand Up @@ -603,13 +603,13 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
Expand All @@ -619,13 +619,13 @@ def forward(
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
Expand Down Expand Up @@ -963,24 +963,24 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
position_embeddings,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/bamba/modeling_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def eager_attention_forward(


# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Removes the interleaving of cos and sin from GLM
Expand All @@ -240,8 +240,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
Expand Down Expand Up @@ -305,6 +303,8 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if position_embeddings is None:
raise ValueError("position_embeddings = (cos, sin) must be given")
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

Expand Down
4 changes: 1 addition & 3 deletions src/transformers/models/bamba/modular_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class BambaRotaryEmbedding(LlamaRotaryEmbedding):


# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Removes the interleaving of cos and sin from GLM
Expand All @@ -154,8 +154,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
Expand Down
6 changes: 2 additions & 4 deletions src/transformers/models/chameleon/modeling_chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,14 @@ def rotate_half(x):


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
Expand Down Expand Up @@ -532,7 +530,7 @@ def forward(
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
Expand Down
60 changes: 36 additions & 24 deletions src/transformers/models/codegen/modeling_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,22 @@
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
sin, cos = torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
out = torch.cat((sin, cos), dim=1)
return out


# Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
return torch.concat((-x2, x1), dim=-1)


# Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
sin = torch.repeat_interleave(sin, 2, dim=-1)
cos = torch.repeat_interleave(cos, 2, dim=-1)
return (tensor * cos) + (rotate_every_two(tensor) * sin)


Expand Down Expand Up @@ -87,25 +88,24 @@ def __init__(self, config, layer_idx=None):

self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = config.rotary_dim
pos_embd_dim = self.rotary_dim or self.embed_dim
pos_embd_dim = self.rotary_dim or self.head_dim
# `embed_positions` of shape `(max_positions, 2 * pos_embd_dim)`
self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

# TODO: Add comment on the role of mp_num. Why this complex reshaping?
def _split_heads(self, x, n_head, dim_head, mp_num):
reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
return reshaped

def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Merges attn_head_size dim and num_attn_heads dim into n_ctx
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
if not (4 <= tensor.dim() <= 5):
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {tensor.dim()}")
tensor = tensor.transpose(-2, -3).contiguous()
new_shape = tensor.size()[:-2] + (self.num_attention_heads * self.head_dim,)
return tensor.view(new_shape)

def _attn(
Expand Down Expand Up @@ -153,33 +153,44 @@ def forward(
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
qkv = self.qkv_proj(hidden_states)
if position_ids is None:
raise ValueError("position_ids must be given")
qkv = self.qkv_proj(hidden_states) # (B, T, 3 * n_head * head_dim)
# TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
mp_num = 4
qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

local_dim = self.head_dim * self.num_attention_heads // mp_num
query, value, key = torch.split(qkv_split, local_dim, dim=-1)
# Shapes (B, T, mp_num, local_dim), local_dim = n_head * head_dim // mp_num
query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)

value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
value = value.permute(0, 2, 1, 3)
# query, key, value: (B, T, n_head, head_dim)
value = value.transpose(1, 2) # (B, n_head, T, head_dim)

embed_positions = self.embed_positions
if embed_positions.device != position_ids.device:
embed_positions = embed_positions.to(position_ids.device)
self.embed_positions = embed_positions

sincos = embed_positions[position_ids]
if position_ids.dim() == 1:
position_ids = position_ids.unsqueeze(0)
embed_positions = embed_positions.unsqueeze(0).repeat(position_ids.shape[0], 1, 1)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
sin = sin.unsqueeze(2)
cos = cos.unsqueeze(2)
# cos, sin: (B, T, 1, rotary_dim // 2)

if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim :]
k_rot = key[..., : self.rotary_dim]
k_pass = key[..., self.rotary_dim :]

q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]
q_rot = query[..., : self.rotary_dim]
q_pass = query[..., self.rotary_dim :]

k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
Expand All @@ -190,8 +201,9 @@ def forward(
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)

key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
key = key.transpose(1, 2)
query = query.transpose(1, 2)
# query, key, value: (B, n_head, T, head_dim)

# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
Expand All @@ -207,7 +219,7 @@ def forward(
# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self._merge_heads(attn_output)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

Expand Down
Loading

0 comments on commit 2901c93

Please sign in to comment.