diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 56985dee5ac0aa..7217b60a2e8406 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -570,7 +570,6 @@
     "models.m2m_100": ["M2M100Config"],
     "models.mamba": ["MambaConfig"],
     "models.mamba2": ["Mamba2Config"],
-    "models.xlstm": ["xLSTMConfig"],
     "models.marian": ["MarianConfig"],
     "models.markuplm": [
         "MarkupLMConfig",
@@ -869,6 +868,7 @@
     "models.xlm_roberta": ["XLMRobertaConfig"],
     "models.xlm_roberta_xl": ["XLMRobertaXLConfig"],
     "models.xlnet": ["XLNetConfig"],
+    "models.xlstm": ["xLSTMConfig"],
     "models.xmod": ["XmodConfig"],
     "models.yolos": ["YolosConfig"],
     "models.yoso": ["YosoConfig"],
@@ -2732,13 +2732,6 @@
             "Mamba2PreTrainedModel",
         ]
     )
-    _import_structure["models.xlstm"].extend(
-        [
-            "xLSTMForCausalLM",
-            "xLSTMModel",
-            "xLSTMPreTrainedModel",
-        ]
-    )
     _import_structure["models.marian"].extend(
         ["MarianForCausalLM", "MarianModel", "MarianMTModel", "MarianPreTrainedModel"]
     )
@@ -3861,6 +3854,13 @@
             "load_tf_weights_in_xlnet",
         ]
     )
+    _import_structure["models.xlstm"].extend(
+        [
+            "xLSTMForCausalLM",
+            "xLSTMModel",
+            "xLSTMPreTrainedModel",
+        ]
+    )
     _import_structure["models.xmod"].extend(
         [
             "XmodForCausalLM",
diff --git a/src/transformers/models/xlstm/configuration_xlstm.py b/src/transformers/models/xlstm/configuration_xlstm.py
index b841386e6dbce9..00398585537670 100644
--- a/src/transformers/models/xlstm/configuration_xlstm.py
+++ b/src/transformers/models/xlstm/configuration_xlstm.py
@@ -14,10 +14,9 @@
 # limitations under the License.
 """XLSTM configuration"""
 
-
-
 from ...configuration_utils import PretrainedConfig
-from ...utils import logging, is_xlstm_available
+from ...utils import is_xlstm_available, logging
+
 
 if is_xlstm_available():
     from xlstm.xlstm_large.model import (
diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py
index 80ffbd5d71ff41..135ea4f97fb0a3 100644
--- a/src/transformers/models/xlstm/modeling_xlstm.py
+++ b/src/transformers/models/xlstm/modeling_xlstm.py
@@ -8,8 +8,6 @@
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-
-
 from ...generation import GenerationMixin
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -21,13 +19,14 @@
     is_xlstm_available,
 )
 
+
 if is_xlstm_available():
     from xlstm.xlstm_large.model import (
         RMSNorm,
         mLSTMBlock,
         mLSTMStateType,
         soft_cap,
-        )
+    )
 else:
     mLSTMBlock = None
 
@@ -421,7 +420,11 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         inserted_bos_token = False
-        if (cache_params is None or cache_params.rnn_state_initial) and self.config.force_bos_token_insert and input_ids is not None:
+        if (
+            (cache_params is None or cache_params.rnn_state_initial)
+            and self.config.force_bos_token_insert
+            and input_ids is not None
+        ):
             if not is_torchdynamo_compiling():
                 if bool(torch.all(input_ids[0, 0] != self.config.bos_token_id).cpu()):
                     input_ids = torch.cat(
@@ -444,9 +447,21 @@ def forward(
         if inserted_bos_token:
             hidden_states = hidden_states[:, 1:]
         if hasattr(xlstm_outputs, "hidden_states"):
-            xlstm_outputs_mod = (hidden_states, xlstm_outputs.cache_params, tuple(hidden_state[:, 1:] for hidden_state in xlstm_outputs.hidden_states) if xlstm_outputs.hidden_states is not None else None)
-        elif len(xlstm_outputs) == 3:
-            xlstm_outputs_mod = (hidden_states, xlstm_outputs[1], tuple(hidden_state[:, 1:] for hidden_state in xlstm_outputs[2]) if xlstm_outputs[2] is not None else None)
+            xlstm_outputs_mod = (
+                hidden_states,
+                xlstm_outputs.cache_params,
+                tuple(hidden_state[:, 1:] for hidden_state in xlstm_outputs.hidden_states)
+                if xlstm_outputs.hidden_states is not None
+                else None,
+            )
+        elif len(xlstm_outputs) == 3:
+            xlstm_outputs_mod = (
+                hidden_states,
+                xlstm_outputs[1],
+                tuple(hidden_state[:, 1:] for hidden_state in xlstm_outputs[2])
+                if xlstm_outputs[2] is not None
+                else None,
+            )
         elif len(xlstm_outputs) == 2:
             xlstm_outputs_mod = (hidden_states, xlstm_outputs[1])
         elif len(xlstm_outputs) == 1:
diff --git a/tests/models/xlstm/test_modeling_xlstm.py b/tests/models/xlstm/test_modeling_xlstm.py
index 7ddd238b3d90cc..5bdb1c6f96b8a2 100644
--- a/tests/models/xlstm/test_modeling_xlstm.py
+++ b/tests/models/xlstm/test_modeling_xlstm.py
@@ -35,7 +35,7 @@
         xLSTMForCausalLM,
         xLSTMModel,
     )
-    from transformers.models.xlstm.modeling_xlstm import xLSTMCache, mLSTMBlock
+    from transformers.models.xlstm.modeling_xlstm import mLSTMBlock, xLSTMCache
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
 else:
     is_torch_greater_or_equal_than_2_2 = False
@@ -53,7 +53,7 @@ def __init__(
         vocab_size=99,
         embedding_dim=128,
         qk_dim_factor=0.5,
-        v_dim_factor=1.,
+        v_dim_factor=1.0,
         num_blocks=2,
         max_position_embeddings=512,
         type_vocab_size=16,
@@ -94,7 +94,6 @@ def __init__(
         self.num_hidden_layers = self.num_blocks
         self.hidden_size = self.embedding_dim
 
-
     def get_large_model_config(self):
         cfg = xLSTMConfig.from_pretrained("NX-AI/xLSTM-7b")
         # this is needed for compatibility with generic tests
@@ -102,9 +101,7 @@ def get_large_model_config(self):
         cfg.num_hidden_layers = cfg.num_blocks
         return cfg
 
-    def prepare_config_and_inputs(
-        self, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
-    ):
+    def prepare_config_and_inputs(self, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        sequence_labels = None
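
Note (not part of the patch): a minimal usage sketch of the xLSTM classes whose exports are re-sorted above. It assumes the optional `xlstm` package is installed and uses the `NX-AI/xLSTM-7b` checkpoint referenced in the tests; the tokenizer choice and generation settings are illustrative only, not taken from this diff.

# Minimal usage sketch (assumption: the optional `xlstm` package is installed
# and the NX-AI/xLSTM-7b checkpoint referenced in the tests is available).
import torch
from transformers import AutoTokenizer, xLSTMForCausalLM

tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
model = xLSTMForCausalLM.from_pretrained("NX-AI/xLSTM-7b", torch_dtype=torch.bfloat16)

inputs = tokenizer("The xLSTM architecture", return_tensors="pt")
# When `config.force_bos_token_insert` is set and the recurrent cache is still
# in its initial state, the model prepends the BOS token itself and strips the
# corresponding position from the returned hidden states (see the forward()
# hunk above), so callers do not need to add it manually.
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))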