Skip to content

Commit

Permalink
Modular changes
Browse files Browse the repository at this point in the history
  • Loading branch information
LysandreJik committed Dec 12, 2024
1 parent a691ccb commit 1ac58da
Show file tree
Hide file tree
Showing 17 changed files with 116 additions and 252 deletions.
113 changes: 10 additions & 103 deletions src/transformers/models/gemma/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,111 +13,18 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
)


_import_structure = {
"configuration_gemma": ["GemmaConfig"],
}

try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_gemma"] = ["GemmaTokenizer"]

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_gemma_fast"] = ["GemmaTokenizerFast"]


try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gemma"] = [
"GemmaForCausalLM",
"GemmaModel",
"GemmaPreTrainedModel",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_gemma"] = [
"FlaxGemmaForCausalLM",
"FlaxGemmaModel",
"FlaxGemmaPreTrainedModel",
]
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_gemma import GemmaConfig

try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_gemma import GemmaTokenizer

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_gemma_fast import GemmaTokenizerFast

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gemma import (
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel,
GemmaPreTrainedModel,
)

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_gemma import (
FlaxGemmaForCausalLM,
FlaxGemmaModel,
FlaxGemmaPreTrainedModel,
)


from .configuration_gemma import *
from .modeling_flax_gemma import *
from .modeling_gemma import *
from .tokenization_gemma import *
from .tokenization_gemma_fast import *
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
3 changes: 3 additions & 0 deletions src/transformers/models/gemma/modeling_flax_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,3 +772,6 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
_CONFIG_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)


__all__ = ["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"]
3 changes: 3 additions & 0 deletions src/transformers/models/gemma/tokenization_gemma_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,6 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = output + bos_token_id + token_ids_1 + eos_token_id

return output


__all__ = ["GemmaTokenizerFast"]
48 changes: 7 additions & 41 deletions src/transformers/models/gemma2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,49 +13,15 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


_import_structure = {
"configuration_gemma2": ["Gemma2Config"],
}

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gemma2"] = [
"Gemma2ForCausalLM",
"Gemma2Model",
"Gemma2PreTrainedModel",
"Gemma2ForSequenceClassification",
"Gemma2ForTokenClassification",
]

if TYPE_CHECKING:
from .configuration_gemma2 import Gemma2Config

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gemma2 import (
Gemma2ForCausalLM,
Gemma2ForSequenceClassification,
Gemma2ForTokenClassification,
Gemma2Model,
Gemma2PreTrainedModel,
)

from .configuration_gemma2 import *
from .modeling_gemma2 import *
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
3 changes: 3 additions & 0 deletions src/transformers/models/gemma2/configuration_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,6 @@ def __init__(
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation


__all__ = ["Gemma2Config"]
9 changes: 9 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,3 +1280,12 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


__all__ = [
"Gemma2ForCausalLM",
"Gemma2Model",
"Gemma2PreTrainedModel",
"Gemma2ForSequenceClassification",
"Gemma2ForTokenClassification",
]
10 changes: 10 additions & 0 deletions src/transformers/models/gemma2/modular_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,3 +903,13 @@ def __init__(self, config):
super().__init__(config)
self.model = Gemma2Model(config)
self.post_init()


__all__ = [
"Gemma2Config",
"Gemma2ForCausalLM",
"Gemma2Model",
"Gemma2PreTrainedModel",
"Gemma2ForSequenceClassification",
"Gemma2ForTokenClassification",
]
57 changes: 8 additions & 49 deletions src/transformers/models/llava_next_video/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,58 +13,17 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


_import_structure = {
"configuration_llava_next_video": ["LlavaNextVideoConfig"],
"processing_llava_next_video": ["LlavaNextVideoProcessor"],
}


try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_llava_next_video"] = ["LlavaNextVideoImageProcessor"]

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_llava_next_video"] = [
"LlavaNextVideoForConditionalGeneration",
"LlavaNextVideoPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_llava_next_video import LlavaNextVideoConfig
from .processing_llava_next_video import LlavaNextVideoProcessor

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_llava_next_video import LlavaNextVideoImageProcessor

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_llava_next_video import (
LlavaNextVideoForConditionalGeneration,
LlavaNextVideoPreTrainedModel,
)

from .configuration_llava_next_video import *
from .image_processing_llava_next_video import *
from .modeling_llava_next_video import *
from .processing_llava_next_video import *
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,6 @@ def __init__(
self.text_config = text_config

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


__all__ = ["LlavaNextVideoConfig"]
Original file line number Diff line number Diff line change
Expand Up @@ -414,3 +414,6 @@ def preprocess(

data = {"pixel_values_videos": pixel_values}
return BatchFeature(data=data, tensor_type=return_tensors)


__all__ = ["LlavaNextVideoImageProcessor"]
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,6 @@ def forward(self, image_features):
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()


class LlavaNextVideoMultiModalProjector(nn.Module):
def __init__(self, config: LlavaNextVideoConfig):
super().__init__()

self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)

def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states


LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
Expand Down Expand Up @@ -191,6 +176,21 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()


class LlavaNextVideoMultiModalProjector(nn.Module):
def __init__(self, config: LlavaNextVideoConfig):
super().__init__()

self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)

def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Expand Down Expand Up @@ -1157,3 +1157,6 @@ def get_video_features(
video_features = self.multi_modal_projector(video_features)
video_features = torch.split(video_features, frames, dim=0)
return video_features


__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"]
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from transformers.models.llava_next.modeling_llava_next import (
LlavaNextCausalLMOutputWithPast,
LlavaNextForConditionalGeneration,
LlavaNextPreTrainedModel,
image_size_to_num_patches,
)

Expand Down Expand Up @@ -218,6 +219,10 @@ def forward(self, image_features):
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()


class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel):
pass


class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
super().__init__(config, **super_kwargs)
Expand Down Expand Up @@ -641,3 +646,6 @@ def prepare_inputs_for_generation(
model_inputs["image_sizes"] = image_sizes

return model_inputs


__all__ = ["LlavaNextVideoConfig", "LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"]
Loading

0 comments on commit 1ac58da

Please sign in to comment.