Skip to content

Commit

Permalink
finish fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Oct 22, 2024
1 parent 314ed1f commit 8830473
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 23 deletions.
1 change: 1 addition & 0 deletions src/transformers/models/gemma/configuration_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig


Expand Down
1 change: 0 additions & 1 deletion src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
Expand Down
24 changes: 14 additions & 10 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,40 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_greater_or_equal,
logging,
replace_return_docstrings,
)
from .configuration_gemma2 import Gemma2Config


if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward

if is_torch_greater_or_equal("2.5"):
from torch.nn.attention.flex_attention import flex_attention
from typing import List

from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import SequenceClassifierOutputWithPast, TokenClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from .configuration_gemma2 import Gemma2Config


class Gemma2RMSNorm(nn.Module):
Expand Down
3 changes: 1 addition & 2 deletions src/transformers/models/glm/modeling_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from typing import Union

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,14 @@

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_llava_next_video import LlavaNextVideoConfig

Expand Down

0 comments on commit 8830473

Please sign in to comment.