Use native torch rotary embeddings
NielsRogge committed Nov 28, 2023
1 parent e9f78df commit fc83063
Showing 3 changed files with 70 additions and 6 deletions.
@@ -32,6 +32,7 @@
)
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD


original_device = "cuda:1"
hf_device = "cuda:2"

@@ -110,7 +111,9 @@ def gather_inputs(inputs, device, use_bfloat16=True):
    image_std=OPENAI_CLIP_STD,
)
patch_size = original_model.config.vision_config["patch_size"]
processor = CogVLMProcessor(image_processor=image_processor, tokenizer=tokenizer, image_size=image_size, patch_size=patch_size)
processor = CogVLMProcessor(
    image_processor=image_processor, tokenizer=tokenizer, image_size=image_size, patch_size=patch_size
)

original_inputs = gather_inputs(inputs, device=hf_device)
original_inputs["pixel_values"] = torch.stack(original_inputs.pop("images")[0])
63 changes: 59 additions & 4 deletions src/transformers/models/cogvlm/modeling_cogvlm.py
@@ -344,6 +344,44 @@ def attention_fn(
    return context_layer, attention_scores


class CogVLMRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False, device=torch.device('cpu')):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Uses a different permutation than the paper, but yields the same result
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # shape: [seq_len, 1 (broadcast over batch * num_heads), dim]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            cos_cached = cos_cached.to(x.dtype)
            sin_cached = sin_cached.to(x.dtype)
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]


class CogVLMVisionExpertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -353,7 +391,8 @@ def __init__(self, config):
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
        self.rotary_emb = CogVLMRotaryEmbedding(dim=self.head_dim)
        # self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
        self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
        self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
@@ -394,9 +433,11 @@ def forward(
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        query_states, key_states = self.rotary_emb(
            query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1
        )
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)
        # query_states, key_states = self.rotary_emb(
        #     query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1
        # )

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
@@ -428,6 +469,20 @@ def forward(
            if output_attentions
            else (attn_output, None, past_key_value)
        )


def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions


def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
    # q, k have shape [batch_size, num_heads, seq_len, head_dim]; cos, sin have shape [seq_len, 1, head_dim]
    cos = nn.functional.embedding(position_id, cos.squeeze(1)).unsqueeze(1)
    sin = nn.functional.embedding(position_id, sin.squeeze(1)).unsqueeze(1)
    q = (q * cos) + (rotate_half(q) * sin)
    k = (k * cos) + (rotate_half(k) * sin)
    return q, k


class CogVLMDecoderLayer(nn.Module):
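
A quick usage sketch (not part of the commit): assuming this PR branch is installed, q, k and v have shape [batch, num_heads, seq_len, head_dim], and position_ids has shape [batch, seq_len], the new CogVLMRotaryEmbedding module and the apply_rotary_pos_emb_index_bhs helper added above can be exercised end to end as follows.

# Hypothetical sanity check; the shapes and the installed PR branch are assumptions, not part of the commit.
import torch
from transformers.models.cogvlm.modeling_cogvlm import CogVLMRotaryEmbedding, apply_rotary_pos_emb_index_bhs

batch, num_heads, seq_len, head_dim = 2, 4, 16, 32
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

rotary_emb = CogVLMRotaryEmbedding(dim=head_dim)
cos, sin = rotary_emb(v, seq_len=seq_len)  # each of shape [seq_len, 1, head_dim]
q_rot, k_rot = apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
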
8 changes: 7 additions & 1 deletion tests/models/cogvlm/test_processor_cogvlm.py
@@ -25,7 +25,13 @@
if is_vision_available():
    from PIL import Image

    from transformers import AutoProcessor, CLIPImageProcessor, CogVLMProcessor, LlamaTokenizer, PreTrainedTokenizerFast
    from transformers import (
        AutoProcessor,
        CLIPImageProcessor,
        CogVLMProcessor,
        LlamaTokenizer,
        PreTrainedTokenizerFast,
    )


@require_vision
