From fc8306343bade0171fd4947d289d2ad1d57572ab Mon Sep 17 00:00:00 2001
From: Niels Rogge
Date: Tue, 28 Nov 2023 10:06:33 +0100
Subject: [PATCH] Use native torch rotary embeddings

---
 .../convert_cogvlm_original_to_pytorch.py     |  5 +-
 .../models/cogvlm/modeling_cogvlm.py          | 63 +++++++++++++++++--
 tests/models/cogvlm/test_processor_cogvlm.py  |  8 ++-
 3 files changed, 70 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/cogvlm/convert_cogvlm_original_to_pytorch.py b/src/transformers/models/cogvlm/convert_cogvlm_original_to_pytorch.py
index a6db8e506b1fde..164cf009215fef 100644
--- a/src/transformers/models/cogvlm/convert_cogvlm_original_to_pytorch.py
+++ b/src/transformers/models/cogvlm/convert_cogvlm_original_to_pytorch.py
@@ -32,6 +32,7 @@
 )
 from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 
+
 original_device = "cuda:1"
 hf_device = "cuda:2"
 
@@ -110,7 +111,9 @@ def gather_inputs(inputs, device, use_bfloat16=True):
     image_std=OPENAI_CLIP_STD,
 )
 patch_size = original_model.config.vision_config["patch_size"]
-processor = CogVLMProcessor(image_processor=image_processor, tokenizer=tokenizer, image_size=image_size, patch_size=patch_size)
+processor = CogVLMProcessor(
+    image_processor=image_processor, tokenizer=tokenizer, image_size=image_size, patch_size=patch_size
+)
 
 original_inputs = gather_inputs(inputs, device=hf_device)
 original_inputs["pixel_values"] = torch.stack(original_inputs.pop("images")[0])
diff --git a/src/transformers/models/cogvlm/modeling_cogvlm.py b/src/transformers/models/cogvlm/modeling_cogvlm.py
index 4d407fac137749..c7593486609d0b 100644
--- a/src/transformers/models/cogvlm/modeling_cogvlm.py
+++ b/src/transformers/models/cogvlm/modeling_cogvlm.py
@@ -344,6 +344,44 @@ def attention_fn(
     return context_layer, attention_scores
 
 
+class CogVLMRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, base=10000, precision=torch.half, learnable=False, device=torch.device('cpu')):
+        super().__init__()
+        inv_freq = 1. / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
+        self.learnable = learnable
+        if learnable:
+            self.inv_freq = torch.nn.Parameter(inv_freq)
+            self.max_seq_len_cached = None
+        else:
+            self.register_buffer('inv_freq', inv_freq)
+            self.max_seq_len_cached = None
+            self.cos_cached = None
+            self.sin_cached = None
+        self.precision = precision
+
+    def forward(self, x, seq_dim=1, seq_len=None):
+        if seq_len is None:
+            seq_len = x.shape[seq_dim]
+        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
+            self.max_seq_len_cached = None if self.learnable else seq_len
+            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
+            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            if self.precision == torch.bfloat16:
+                emb = emb.float()
+
+            # [sx, 1 (b * np), hn]
+            cos_cached = emb.cos()[:, None, :]
+            sin_cached = emb.sin()[:, None, :]
+            cos_cached = cos_cached.to(x.dtype)
+            sin_cached = sin_cached.to(x.dtype)
+            if self.learnable:
+                return cos_cached, sin_cached
+            self.cos_cached, self.sin_cached = cos_cached, sin_cached
+        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+
+
 class CogVLMVisionExpertAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -353,7 +391,8 @@ def __init__(self, config):
         self.head_dim = self.hidden_size // self.num_heads
         self.max_position_embeddings = config.max_position_embeddings
 
-        self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
+        self.rotary_emb = CogVLMRotaryEmbedding(dim=self.head_dim)
+        # self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
         self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
         self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
@@ -394,9 +433,11 @@ def forward(
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
 
-        query_states, key_states = self.rotary_emb(
-            query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1
-        )
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)
+        # query_states, key_states = self.rotary_emb(
+        #     query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1
+        # )
 
         if past_key_value is not None:
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
@@ -428,6 +469,20 @@ def forward(
             if output_attentions
             else (attn_output, None, past_key_value)
         )
+
+
+def rotate_half(x):
+    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=x.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions
+
+
+def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
+    # batch_size, num_head, seq_len, hidden_size
+    cos, sin = nn.functional.embedding(position_id, cos.squeeze(1)).unsqueeze(1), \
+        nn.functional.embedding(position_id, sin.squeeze(1)).unsqueeze(1)
+    q = (q * cos) + (rotate_half(q) * sin)
+    k = (k * cos) + (rotate_half(k) * sin)
+    return q, k
 
 
 class CogVLMDecoderLayer(nn.Module):
diff --git a/tests/models/cogvlm/test_processor_cogvlm.py b/tests/models/cogvlm/test_processor_cogvlm.py
index b457aae7d321cf..339c94a94be976 100644
--- a/tests/models/cogvlm/test_processor_cogvlm.py
+++ b/tests/models/cogvlm/test_processor_cogvlm.py
@@ -25,7 +25,13 @@
 if is_vision_available():
     from PIL import Image
 
-    from transformers import AutoProcessor, CLIPImageProcessor, CogVLMProcessor, LlamaTokenizer, PreTrainedTokenizerFast
+    from transformers import (
+        AutoProcessor,
+        CLIPImageProcessor,
+        CogVLMProcessor,
+        LlamaTokenizer,
+        PreTrainedTokenizerFast,
+    )
 
 
 @require_vision
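
Below is a short, standalone sketch (not part of the patch above) of how the new rotary path is exercised. It mirrors the call sequence in CogVLMVisionExpertAttention.forward: CogVLMRotaryEmbedding builds the cos/sin caches and apply_rotary_pos_emb_index_bhs gathers them by position_ids before rotating q and k. The tensor sizes are illustrative assumptions, and the import assumes a transformers checkout with this patch applied.

    # Standalone sketch: shapes below are illustrative, not taken from the CogVLM config.
    import torch

    from transformers.models.cogvlm.modeling_cogvlm import (
        CogVLMRotaryEmbedding,
        apply_rotary_pos_emb_index_bhs,
    )

    batch_size, num_heads, seq_len, head_dim = 2, 4, 16, 32

    # [batch_size, num_heads, seq_len, head_dim], as in the attention layer
    query_states = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key_states = torch.randn(batch_size, num_heads, seq_len, head_dim)
    value_states = torch.randn(batch_size, num_heads, seq_len, head_dim)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len)

    # cos/sin caches of shape [seq_len, 1, head_dim]; value_states only supplies device/dtype
    # because seq_len is passed explicitly
    rotary_emb = CogVLMRotaryEmbedding(dim=head_dim)
    cos, sin = rotary_emb(value_states, seq_len=seq_len)

    # position_ids index into the cached cos/sin rows before q and k are rotated
    query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)
    assert query_states.shape == (batch_size, num_heads, seq_len, head_dim)
    assert key_states.shape == (batch_size, num_heads, seq_len, head_dim)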