Fix: Import structure cleaning with Ruff.
kpoeppel committed Dec 20, 2024
1 parent 434e157 commit ba5c457
Showing 4 changed files with 35 additions and 24 deletions.
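
All four diffs are mechanical cleanups of the kind Ruff applies automatically: `_import_structure` entries moved into alphabetical order, `from`-imports sorted, stray blank lines removed, and long statements wrapped. Assuming a standard setup, the import sorting comes from something like `ruff check --select I --fix .` and the layout fixes from `ruff format .`; the exact rule set depends on the repository's Ruff configuration.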
16 changes: 8 additions & 8 deletions src/transformers/__init__.py
@@ -570,7 +570,6 @@
"models.m2m_100": ["M2M100Config"],
"models.mamba": ["MambaConfig"],
"models.mamba2": ["Mamba2Config"],
"models.xlstm": ["xLSTMConfig"],
"models.marian": ["MarianConfig"],
"models.markuplm": [
"MarkupLMConfig",
@@ -869,6 +868,7 @@
"models.xlm_roberta": ["XLMRobertaConfig"],
"models.xlm_roberta_xl": ["XLMRobertaXLConfig"],
"models.xlnet": ["XLNetConfig"],
"models.xlstm": ["xLSTMConfig"],
"models.xmod": ["XmodConfig"],
"models.yolos": ["YolosConfig"],
"models.yoso": ["YosoConfig"],
@@ -2732,13 +2732,6 @@
"Mamba2PreTrainedModel",
]
)
-_import_structure["models.xlstm"].extend(
-    [
-        "xLSTMForCausalLM",
-        "xLSTMModel",
-        "xLSTMPreTrainedModel",
-    ]
-)
_import_structure["models.marian"].extend(
["MarianForCausalLM", "MarianModel", "MarianMTModel", "MarianPreTrainedModel"]
)
@@ -3861,6 +3854,13 @@
"load_tf_weights_in_xlnet",
]
)
+_import_structure["models.xlstm"].extend(
+    [
+        "xLSTMForCausalLM",
+        "xLSTMModel",
+        "xLSTMPreTrainedModel",
+    ]
+)
_import_structure["models.xmod"].extend(
[
"XmodForCausalLM",
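
For context on why these blocks can move freely: `transformers/__init__.py` registers every public symbol in a lazy import table, so relocating the `models.xlstm` entries from the `m*` region to the `x*` region only restores alphabetical order; nothing is imported until first attribute access. A minimal sketch of the pattern (simplified; the real `_LazyModule` lives in `transformers.utils` and also handles type checking and backend errors):

    import importlib
    from types import ModuleType


    class _LazyModule(ModuleType):
        """Resolve attributes to their defining submodule on first access."""

        def __init__(self, name: str, import_structure: dict):
            super().__init__(name)
            # Invert {submodule: [names]} into {name: submodule}.
            self._name_to_module = {
                attr: mod for mod, attrs in import_structure.items() for attr in attrs
            }

        def __getattr__(self, attr: str):
            # Import the owning submodule lazily, then pull the symbol from it.
            module = importlib.import_module("." + self._name_to_module[attr], self.__name__)
            return getattr(module, attr)


    _import_structure = {
        "models.xlstm": ["xLSTMConfig", "xLSTMForCausalLM", "xLSTMModel", "xLSTMPreTrainedModel"],
    }

    # transformers then replaces its own module object with the lazy proxy:
    # sys.modules[__name__] = _LazyModule(__name__, _import_structure)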
5 changes: 2 additions & 3 deletions src/transformers/models/xlstm/configuration_xlstm.py
@@ -14,10 +14,9 @@
# limitations under the License.
"""XLSTM configuration"""

-
-
from ...configuration_utils import PretrainedConfig
-from ...utils import logging, is_xlstm_available
+from ...utils import is_xlstm_available, logging


if is_xlstm_available():
from xlstm.xlstm_large.model import (
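
The guarded import following the reordered line is the standard optional-dependency pattern: configuration code must remain importable even when the `xlstm` package is absent. A sketch of how such an availability check is typically implemented (the real `is_xlstm_available` in this branch may differ):

    import importlib.util


    def is_xlstm_available() -> bool:
        # True when the optional `xlstm` package is importable.
        return importlib.util.find_spec("xlstm") is not None


    if is_xlstm_available():
        from xlstm.xlstm_large.model import mLSTMBlock  # heavy, optional import
    else:
        mLSTMBlock = None  # sentinel so the module still imports without xlstm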
29 changes: 22 additions & 7 deletions src/transformers/models/xlstm/modeling_xlstm.py
@@ -8,8 +8,6 @@
from torch import nn
from torch.nn import CrossEntropyLoss

-
-
from ...generation import GenerationMixin
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -21,13 +19,14 @@
is_xlstm_available,
)

+
if is_xlstm_available():
from xlstm.xlstm_large.model import (
RMSNorm,
mLSTMBlock,
mLSTMStateType,
soft_cap,
-        )
+    )
else:
mLSTMBlock = None

@@ -421,7 +420,11 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

inserted_bos_token = False
-if (cache_params is None or cache_params.rnn_state_initial) and self.config.force_bos_token_insert and input_ids is not None:
+if (
+    (cache_params is None or cache_params.rnn_state_initial)
+    and self.config.force_bos_token_insert
+    and input_ids is not None
+):
if not is_torchdynamo_compiling():
if bool(torch.all(input_ids[0, 0] != self.config.bos_token_id).cpu()):
input_ids = torch.cat(
@@ -444,9 +447,21 @@
if inserted_bos_token:
hidden_states = hidden_states[:, 1:]
if hasattr(xlstm_outputs, "hidden_states"):
-xlstm_outputs_mod = (hidden_states, xlstm_outputs.cache_params, tuple(hidden_state[:, 1:] for hidden_state in xlstm_outputs.hidden_states) if xlstm_outputs.hidden_states is not None else None)
-elif len(xlstm_outputs) == 3:
-xlstm_outputs_mod = (hidden_states, xlstm_outputs[1], tuple(hidden_state[:, 1:] for hidden_state in xlstm_outputs[2]) if xlstm_outputs[2] is not None else None)
+xlstm_outputs_mod = (
+    hidden_states,
+    xlstm_outputs.cache_params,
+    tuple(hidden_state[:, 1:] for hidden_state in xlstm_outputs.hidden_states)
+    if xlstm_outputs.hidden_states is not None
+    else None,
+)
+elif len(xlstm_outputs) == 3:
+xlstm_outputs_mod = (
+    hidden_states,
+    xlstm_outputs[1],
+    tuple(hidden_state[:, 1:] for hidden_state in xlstm_outputs[2])
+    if xlstm_outputs[2] is not None
+    else None,
+)
elif len(xlstm_outputs) == 2:
xlstm_outputs_mod = (hidden_states, xlstm_outputs[1])
elif len(xlstm_outputs) == 1:
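
The logic reformatted in these two hunks prepends a BOS token when the recurrent cache is fresh and the sequence does not already start with one, then strips that extra position from every returned hidden state. A standalone sketch under those assumptions (function name and signature are illustrative, not the model's API):

    import torch


    def maybe_insert_bos(input_ids: torch.Tensor, bos_token_id: int):
        """Prepend bos_token_id when the batch does not already start with it."""
        inserted = False
        if bool(torch.all(input_ids[0, 0] != bos_token_id).cpu()):
            bos = torch.full(
                (input_ids.shape[0], 1),
                bos_token_id,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            input_ids = torch.cat([bos, input_ids], dim=1)
            inserted = True
        return input_ids, inserted


    # After the forward pass, the matching cleanup drops position 0:
    # hidden_states = hidden_states[:, 1:] if inserted else hidden_states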
9 changes: 3 additions & 6 deletions tests/models/xlstm/test_modeling_xlstm.py
@@ -35,7 +35,7 @@
xLSTMForCausalLM,
xLSTMModel,
)
-from transformers.models.xlstm.modeling_xlstm import xLSTMCache, mLSTMBlock
+from transformers.models.xlstm.modeling_xlstm import mLSTMBlock, xLSTMCache
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
else:
is_torch_greater_or_equal_than_2_2 = False
@@ -53,7 +53,7 @@ def __init__(
vocab_size=99,
embedding_dim=128,
qk_dim_factor=0.5,
-v_dim_factor=1.,
+v_dim_factor=1.0,
num_blocks=2,
max_position_embeddings=512,
type_vocab_size=16,
@@ -94,17 +94,14 @@ def __init__(
self.num_hidden_layers = self.num_blocks
self.hidden_size = self.embedding_dim

-
def get_large_model_config(self):
cfg = xLSTMConfig.from_pretrained("NX-AI/xLSTM-7b")
# this is needed for compatibility with generic tests
cfg.hidden_size = cfg.embedding_dim
cfg.num_hidden_layers = cfg.num_blocks
return cfg

-def prepare_config_and_inputs(
-    self, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
-):
+def prepare_config_and_inputs(self, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

sequence_labels = None
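
With a development install of this branch, these tests can be run with something like `python -m pytest tests/models/xlstm/test_modeling_xlstm.py`; cases that depend on the optional `xlstm` package will presumably be skipped when it is not installed.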
