Refactor and fix style
yonigozlan committed Nov 14, 2024
1 parent 991b3bd commit 8c8c882
Showing 8 changed files with 104 additions and 223 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/__init__.py
@@ -99,6 +99,7 @@
git,
glm,
glpn,
got_ocr2,
gpt2,
gpt_bigcode,
gpt_neo,
@@ -204,7 +205,6 @@
pvt,
pvt_v2,
qwen2,
got_ocr2,
qwen2_audio,
qwen2_moe,
qwen2_vl,
4 changes: 2 additions & 2 deletions src/transformers/models/auto/configuration_auto.py
@@ -116,6 +116,7 @@
("git", "GitConfig"),
("glm", "GlmConfig"),
("glpn", "GLPNConfig"),
("got-ocr2", "GotOcr2Config"),
("gpt-sw3", "GPT2Config"),
("gpt2", "GPT2Config"),
("gpt_bigcode", "GPTBigCodeConfig"),
@@ -223,7 +224,6 @@
("pvt_v2", "PvtV2Config"),
("qdqbert", "QDQBertConfig"),
("qwen2", "Qwen2Config"),
("got-ocr2", "GotOcr2Config"),
("qwen2_audio", "Qwen2AudioConfig"),
("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
("qwen2_moe", "Qwen2MoeConfig"),
@@ -420,6 +420,7 @@
("git", "GIT"),
("glm", "GLM"),
("glpn", "GLPN"),
("got-ocr2", "GOT-OCR2"),
("gpt-sw3", "GPT-Sw3"),
("gpt2", "OpenAI GPT-2"),
("gpt_bigcode", "GPTBigCode"),
@@ -540,7 +541,6 @@
("pvt_v2", "PVTv2"),
("qdqbert", "QDQBert"),
("qwen2", "Qwen2"),
("got-ocr2", "GOT-OCR2"),
("qwen2_audio", "Qwen2Audio"),
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
("qwen2_moe", "Qwen2MoE"),
4 changes: 2 additions & 2 deletions src/transformers/models/auto/modeling_auto.py
@@ -113,6 +113,7 @@
("git", "GitModel"),
("glm", "GlmModel"),
("glpn", "GLPNModel"),
("got-ocr2", "GotOcr2Model"),
("gpt-sw3", "GPT2Model"),
("gpt2", "GPT2Model"),
("gpt_bigcode", "GPTBigCodeModel"),
@@ -209,7 +210,6 @@
("pvt_v2", "PvtV2Model"),
("qdqbert", "QDQBertModel"),
("qwen2", "Qwen2Model"),
("got-ocr2", "GotOcr2Model"),
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
("qwen2_moe", "Qwen2MoeModel"),
("qwen2_vl", "Qwen2VLModel"),
@@ -489,6 +489,7 @@
("gemma2", "Gemma2ForCausalLM"),
("git", "GitForCausalLM"),
("glm", "GlmForCausalLM"),
("got-ocr2", "GotOcr2ForConditionalGeneration"),
("gpt-sw3", "GPT2LMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("gpt_bigcode", "GPTBigCodeForCausalLM"),
@@ -530,7 +531,6 @@
("prophetnet", "ProphetNetForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
("qwen2", "Qwen2ForCausalLM"),
("got-ocr2", "GotOcr2ForConditionalGeneration"),
("qwen2_moe", "Qwen2MoeForCausalLM"),
("recurrent_gemma", "RecurrentGemmaForCausalLM"),
("reformer", "ReformerModelWithLMHead"),
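For reference, a minimal sketch of what the `got-ocr2` auto-mapping entries above enable; it assumes the default `GotOcr2Config` values are enough to instantiate a model, which this commit does not show:

```python
# Sketch only: relies on the "got-ocr2" entries registered above in
# CONFIG_MAPPING_NAMES and MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, and assumes the
# default GotOcr2Config values build a valid (possibly large) model.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.for_model("got-ocr2")         # resolves to GotOcr2Config
model = AutoModelForCausalLM.from_config(config)  # resolves to GotOcr2ForConditionalGeneration
print(type(config).__name__, type(model).__name__)
```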
14 changes: 7 additions & 7 deletions src/transformers/models/auto/tokenization_auto.py
@@ -205,6 +205,13 @@
),
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
(
"got-ocr2",
(
"GotOcr2Tokenizer",
"GotOcr2TokenizerFast" if is_tokenizers_available() else None,
),
),
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
@@ -406,13 +413,6 @@
"Qwen2TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"got-ocr2",
(
"GotOcr2Tokenizer",
"GotOcr2TokenizerFast" if is_tokenizers_available() else None,
),
),
("qwen2_audio", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
(
"qwen2_moe",
23 changes: 12 additions & 11 deletions src/transformers/models/got_ocr2/convert_got_ocr2_weights_to_hf.py
@@ -38,17 +38,17 @@
# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
# Vision encoder mapping
r"model.vision_tower_high.pos_embed": r"visual.pos_embed",
r"model.vision_tower_high.patch_embed.proj": r"visual.patch_embed.projection",
r"model.vision_tower_high.blocks.(\d+).norm": r"visual.layers.\1.layer_norm",
r"model.vision_tower_high.blocks.(\d+).attn": r"visual.layers.\1.attn",
r"model.vision_tower_high.blocks.(\d+).mlp": r"visual.layers.\1.mlp",
r"model.vision_tower_high.neck.0": r"visual.neck.conv1",
r"model.vision_tower_high.neck.1": r"visual.neck.layer_norm1",
r"model.vision_tower_high.neck.2": r"visual.neck.conv2",
r"model.vision_tower_high.neck.3": r"visual.neck.layer_norm2",
r"model.vision_tower_high.net_(\d+).": r"visual_adapter.net_\1.",
r"model.mm_projector_vary" : r"visual_adapter.mm_projector_vary",
r"model.vision_tower_high.pos_embed": r"visual.pos_embed",
r"model.vision_tower_high.patch_embed.proj": r"visual.patch_embed.projection",
r"model.vision_tower_high.blocks.(\d+).norm": r"visual.layers.\1.layer_norm",
r"model.vision_tower_high.blocks.(\d+).attn": r"visual.layers.\1.attn",
r"model.vision_tower_high.blocks.(\d+).mlp": r"visual.layers.\1.mlp",
r"model.vision_tower_high.neck.0": r"visual.neck.conv1",
r"model.vision_tower_high.neck.1": r"visual.neck.layer_norm1",
r"model.vision_tower_high.neck.2": r"visual.neck.conv2",
r"model.vision_tower_high.neck.3": r"visual.neck.layer_norm2",
r"model.vision_tower_high.net_(\d+)": lambda m: f"visual_adapter.conv_up{int(m.group(1)) - 1}",
r"model.mm_projector_vary" : r"visual_adapter.multimodal_projector",
}
# fmt: on

Expand Down Expand Up @@ -103,6 +103,7 @@ def get_got_ocr2_config():
max_window_layers=21,
attention_dropout=0.0,
rope_scaling=None,
image_token_id=151859,
)

return config
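The refactored mapping mixes plain replacement strings with a callable for the `net_(\d+)` keys, so `net_2`/`net_3` are renamed to `conv_up1`/`conv_up2`. A minimal sketch of how such a table could be applied to checkpoint keys; the helper name below is hypothetical, not the script's own function:

```python
import re

def rename_keys(state_dict, key_mapping):
    """Hypothetical helper: apply a pattern table whose values are either
    replacement strings or callables taking the re.Match object."""
    renamed = {}
    for old_key, tensor in state_dict.items():
        new_key = old_key
        for pattern, replacement in key_mapping.items():
            # re.sub accepts a string or a callable as the replacement.
            new_key = re.sub(pattern, replacement, new_key)
        renamed[new_key] = tensor
    return renamed

# e.g. "model.vision_tower_high.net_2.weight" -> "visual_adapter.conv_up1.weight"
```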
99 changes: 23 additions & 76 deletions src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -34,7 +34,7 @@
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -52,18 +52,6 @@
from ...modeling_flash_attention_utils import _flash_attention_forward


if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
else:
flash_attn_varlen_func = None


if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
else:
flash_attn_varlen_func = None


logger = logging.get_logger(__name__)


@@ -98,19 +86,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class GotOcr2VisionAdapter(nn.Module):
def __init__(self, config: GotOcr2VisionConfig):
def __init__(self, language_hidden_size: int, vision_output_channels: int):
super().__init__()
self.config = config

self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
self.mm_projector_vary = nn.Linear(1024, 1024)
self.conv_up1 = nn.Conv2d(
vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False
)
self.conv_up2 = nn.Conv2d(
vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False
)
self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size)

def forward(self, vision_embeddings):
x = self.net_2(vision_embeddings)
x = self.net_3(x)
x = self.conv_up1(vision_embeddings)
x = self.conv_up2(x)
x = x.flatten(2).permute(0, 2, 1)
x = self.mm_projector_vary(x)
x = self.multimodal_projector(x)
return x


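To make the renamed adapter concrete, a standalone shape check of the same two-conv-plus-projection stack; the 256-channel vision output and 1024-dim language hidden size below are illustrative values, not taken from this commit:

```python
import torch
from torch import nn

class VisionAdapterSketch(nn.Module):
    """Standalone copy of the refactored adapter, for shape checking only."""

    def __init__(self, language_hidden_size: int, vision_output_channels: int):
        super().__init__()
        self.conv_up1 = nn.Conv2d(
            vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.conv_up2 = nn.Conv2d(
            vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size)

    def forward(self, vision_embeddings: torch.Tensor) -> torch.Tensor:
        x = self.conv_up1(vision_embeddings)   # (B, 2*C, H/2, W/2)
        x = self.conv_up2(x)                   # (B, hidden, H/4, W/4)
        x = x.flatten(2).permute(0, 2, 1)      # (B, H/4 * W/4, hidden)
        return self.multimodal_projector(x)

adapter = VisionAdapterSketch(language_hidden_size=1024, vision_output_channels=256)
print(adapter(torch.randn(1, 256, 64, 64)).shape)  # torch.Size([1, 256, 1024])
```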
@@ -1546,45 +1536,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


@dataclass
class GotOcr2CausalLMOutputWithPast(ModelOutput):
"""
Base class for GotOcr2 causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
rope_deltas: Optional[torch.LongTensor] = None


class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

@@ -1594,10 +1545,9 @@ def __init__(self, config):
self.model = GotOcr2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.visual_adapter = GotOcr2VisionAdapter(config.vision_config)
self.padding_side = "left"
self.visual_adapter = GotOcr2VisionAdapter(config.hidden_size, config.vision_config.output_channels)

# Initialize weights and apply final processing
self.post_init()

def _update_model_kwargs_for_generation(
@@ -1614,9 +1564,6 @@ def _update_model_kwargs_for_generation(
num_new_tokens=num_new_tokens,
)

if getattr(outputs, "rope_deltas", None) is not None:
model_kwargs["rope_deltas"] = outputs.rope_deltas

return model_kwargs

def forward(
@@ -1632,8 +1579,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1685,13 +1631,18 @@ def forward(
if pixel_values is not None:
image_embeds = self.visual(pixel_values)
image_embeds = self.visual_adapter(image_embeds.last_hidden_state)
n_image_tokens = (input_ids == 151859).sum().item()
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_mask = (input_ids == 151859).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

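The `masked_scatter` call above copies the projected image features into the embedding slots occupied by the image placeholder token (id 151859 in the converted config). A toy illustration of that mechanism with made-up sizes:

```python
import torch

image_token_id = 151859  # matches the image_token_id added in the conversion script
input_ids = torch.tensor([[1, image_token_id, image_token_id, 5, 6, 7]])
inputs_embeds = torch.zeros(1, 6, 4)
image_embeds = torch.ones(1, 2, 4)  # two "image features" for the two placeholders

# Boolean mask broadcast to the embedding width, as in the forward pass above.
image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0., 0., 0.]) -> image slots filled
```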
@@ -1732,13 +1683,12 @@ def forward(
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return GotOcr2CausalLMOutputWithPast(
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=rope_deltas,
)

def prepare_inputs_for_generation(
Expand All @@ -1764,8 +1714,6 @@ def prepare_inputs_for_generation(
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

rope_deltas = kwargs.get("rope_deltas", None)

if cache_position[0] != 0:
pixel_values = None

@@ -1802,7 +1750,6 @@ def prepare_inputs_for_generation(
"use_cache": use_cache,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"rope_deltas": rope_deltas,
}
)
return model_inputs
