From e15687fffe5c9d20598a19aeab721ae0a7580f8a Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 23 Sep 2024 18:28:36 +0100
Subject: [PATCH] Generation: deprecate `PreTrainedModel` inheriting from `GenerationMixin` (#33203)

---
 src/transformers/generation/utils.py | 34 ++++++------------
 src/transformers/modeling_utils.py | 36 ++++++++++++++-----
 .../models/albert/modeling_albert.py | 3 +-
 src/transformers/models/auto/auto_factory.py | 35 ++++++++++++++++++
 src/transformers/models/bark/modeling_bark.py | 3 +-
 src/transformers/models/bart/modeling_bart.py | 5 +--
 src/transformers/models/bert/modeling_bert.py | 3 +-
 .../modeling_bert_generation.py | 3 +-
 .../models/big_bird/modeling_big_bird.py | 3 +-
 .../modeling_bigbird_pegasus.py | 5 +--
 .../models/biogpt/modeling_biogpt.py | 3 +-
 .../models/blenderbot/modeling_blenderbot.py | 5 +--
 .../modeling_blenderbot_small.py | 5 +--
 src/transformers/models/blip/modeling_blip.py | 3 +-
 .../models/blip/modeling_blip_text.py | 3 +-
 .../models/blip_2/modeling_blip_2.py | 3 +-
 .../models/bloom/modeling_bloom.py | 3 +-
 .../models/camembert/modeling_camembert.py | 3 +-
 .../models/chameleon/modeling_chameleon.py | 3 +-
 src/transformers/models/clvp/modeling_clvp.py | 6 ++--
 .../models/codegen/modeling_codegen.py | 3 +-
 .../models/cohere/modeling_cohere.py | 3 +-
 .../models/cpmant/modeling_cpmant.py | 3 +-
 src/transformers/models/ctrl/modeling_ctrl.py | 3 +-
 .../models/data2vec/modeling_data2vec_text.py | 3 +-
 src/transformers/models/dbrx/modeling_dbrx.py | 3 +-
 .../models/electra/modeling_electra.py | 3 +-
 .../models/ernie/modeling_ernie.py | 3 +-
 .../models/falcon/modeling_falcon.py | 3 +-
 .../falcon_mamba/modeling_falcon_mamba.py | 3 +-
 .../models/flaubert/modeling_flaubert.py | 3 +-
 src/transformers/models/fsmt/modeling_fsmt.py | 3 +-
 src/transformers/models/fuyu/modeling_fuyu.py | 3 +-
 src/transformers/models/gemma/diff_gemma.py | 3 +-
 .../models/gemma/modeling_gemma.py | 3 +-
 src/transformers/models/gemma2/diff_gemma2.py | 3 +-
 .../models/gemma2/modeling_gemma2.py | 3 +-
 src/transformers/models/git/modeling_git.py | 3 +-
 src/transformers/models/gpt2/modeling_gpt2.py | 5 +--
 .../gpt_bigcode/modeling_gpt_bigcode.py | 3 +-
 .../models/gpt_neo/modeling_gpt_neo.py | 3 +-
 .../models/gpt_neox/modeling_gpt_neox.py | 3 +-
 .../modeling_gpt_neox_japanese.py | 3 +-
 src/transformers/models/gptj/modeling_gptj.py | 3 +-
 .../models/granite/modeling_granite.py | 3 +-
 .../models/granitemoe/modeling_granitemoe.py | 3 +-
 .../models/idefics2/modeling_idefics2.py | 5 +--
 .../models/imagegpt/modeling_imagegpt.py | 3 +-
 .../instructblip/modeling_instructblip.py | 3 +-
 .../diff_instructblipvideo.py | 3 +-
 .../modeling_instructblipvideo.py | 3 +-
 .../models/jamba/modeling_jamba.py | 3 +-
 .../models/jetmoe/modeling_jetmoe.py | 3 +-
 .../models/kosmos2/modeling_kosmos2.py | 5 +--
 src/transformers/models/led/modeling_led.py | 3 +-
 .../models/llama/modeling_llama.py | 3 +-
 .../models/llava/modeling_llava.py | 5 +--
 .../models/llava_next/modeling_llava_next.py | 5 +--
 .../llava_next_video/diff_llava_next_video.py | 3 +-
 .../modeling_llava_next_video.py | 5 +--
 .../modeling_llava_onevision.py | 5 +--
 .../models/longt5/modeling_longt5.py | 3 +-
 .../models/m2m_100/modeling_m2m_100.py | 3 +-
 .../models/mamba/modeling_mamba.py | 3 +-
 .../models/mamba2/modeling_mamba2.py | 3 +-
 .../models/marian/modeling_marian.py | 5 +--
 .../models/mbart/modeling_mbart.py | 5 +--
 .../megatron_bert/modeling_megatron_bert.py | 3 +-
 .../models/mistral/modeling_mistral.py | 3 +-
 .../models/mixtral/modeling_mixtral.py | 3 +-
 src/transformers/models/mpt/modeling_mpt.py | 3 +-
 src/transformers/models/mt5/modeling_mt5.py | 3 +-
 .../models/musicgen/modeling_musicgen.py | 15 +++++---
 .../modeling_musicgen_melody.py | 15 +++++---
 src/transformers/models/mvp/modeling_mvp.py | 5 +--
 .../models/nemotron/modeling_nemotron.py | 3 +-
 .../models/nllb_moe/modeling_nllb_moe.py | 3 +-
 src/transformers/models/olmo/modeling_olmo.py | 3 +-
 .../models/olmoe/modeling_olmoe.py | 3 +-
 .../models/openai/modeling_openai.py | 3 +-
 src/transformers/models/opt/modeling_opt.py | 3 +-
 .../models/paligemma/modeling_paligemma.py | 3 +-
 .../models/pegasus/modeling_pegasus.py | 5 +--
 .../models/pegasus_x/modeling_pegasus_x.py | 3 +-
 .../models/persimmon/modeling_persimmon.py | 3 +-
 src/transformers/models/phi/modeling_phi.py | 3 +-
 src/transformers/models/phi3/modeling_phi3.py | 3 +-
 .../models/pix2struct/modeling_pix2struct.py | 3 +-
 .../models/plbart/modeling_plbart.py | 5 +--
 .../models/pop2piano/modeling_pop2piano.py | 3 +-
 .../models/prophetnet/modeling_prophetnet.py | 5 +--
 .../models/qwen2/modeling_qwen2.py | 3 +-
 .../qwen2_audio/modeling_qwen2_audio.py | 5 +--
 .../models/qwen2_moe/modeling_qwen2_moe.py | 3 +-
 .../models/qwen2_vl/modeling_qwen2_vl.py | 3 +-
 .../modeling_recurrent_gemma.py | 3 +-
 .../models/reformer/modeling_reformer.py | 3 +-
 .../models/rembert/modeling_rembert.py | 3 +-
 .../models/roberta/modeling_roberta.py | 3 +-
 .../modeling_roberta_prelayernorm.py | 3 +-
 .../models/roc_bert/modeling_roc_bert.py | 3 +-
 .../models/roformer/modeling_roformer.py | 3 +-
 src/transformers/models/rwkv/modeling_rwkv.py | 3 +-
 .../seamless_m4t/modeling_seamless_m4t.py | 5 +--
 .../modeling_seamless_m4t_v2.py | 5 +--
 .../speech_to_text/modeling_speech_to_text.py | 3 +-
 .../models/stablelm/modeling_stablelm.py | 3 +-
 .../models/starcoder2/modeling_starcoder2.py | 3 +-
 .../modeling_switch_transformers.py | 3 +-
 src/transformers/models/t5/modeling_t5.py | 3 +-
 .../models/trocr/modeling_trocr.py | 3 +-
 src/transformers/models/udop/modeling_udop.py | 3 +-
 src/transformers/models/umt5/modeling_umt5.py | 3 +-
 .../video_llava/modeling_video_llava.py | 5 +--
 .../models/vipllava/modeling_vipllava.py | 5 +--
 .../models/whisper/generation_whisper.py | 4 +--
 .../models/whisper/modeling_whisper.py | 3 +-
 src/transformers/models/xglm/modeling_xglm.py | 3 +-
 src/transformers/models/xlm/modeling_xlm.py | 3 +-
 .../xlm_roberta/modeling_xlm_roberta.py | 3 +-
 .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 3 +-
 .../models/xlnet/modeling_xlnet.py | 3 +-
 src/transformers/models/xmod/modeling_xmod.py | 3 +-
 tests/generation/test_utils.py | 9 +++++
 tests/models/auto/test_modeling_auto.py | 18 ++++++++++
 tests/utils/test_modeling_utils.py | 27 ++++++++++++++
 126 files changed, 407 insertions(+), 184 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 2fe92d3e3ed64b..c1aa338a7d8f2f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -34,13 +34,6 @@
 )
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
 from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
-from ..models.auto import (
-    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
-    MODEL_FOR_CAUSAL_LM_MAPPING,
-    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
-    MODEL_FOR_VISION_2_SEQ_MAPPING,
-)
 from ..pytorch_utils import isin_mps_friendly
 from ..tokenization_utils import ExtensionsTrie
 from ..utils import (
@@ -1117,26 +1110,21 @@ def _validate_model_class(self):
         Confirms that the model class is compatible with generation. If not, raises an exception that points to the
         right class to use.
         """
+        # TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from
+        # `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can
+        # safely call `GenerationMixin.generate`
         if not is_torchdynamo_compiling() and not self.can_generate():
-            generate_compatible_mappings = [
-                MODEL_FOR_CAUSAL_LM_MAPPING,
-                MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
-                MODEL_FOR_VISION_2_SEQ_MAPPING,
-                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-                MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+            terminations_with_generation_support = [
+                "ForCausalLM",
+                "ForConditionalGeneration",
+                "ForSpeechSeq2Seq",
+                "ForVision2Seq",
             ]
-            generate_compatible_classes = set()
-            for model_mapping in generate_compatible_mappings:
-                supported_models = model_mapping.get(type(self.config), default=None)
-                if supported_models is not None:
-                    generate_compatible_classes.add(supported_models.__name__)
-            exception_message = (
+            raise TypeError(
                 f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
-                "it doesn't have a language model head."
+                "it doesn't have a language model head. Classes that support generation often end in one of these "
+                f"names: {terminations_with_generation_support}."
             )
-            if generate_compatible_classes:
-                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
-            raise TypeError(exception_message)

     def _validate_assistant(self, assistant_model):
         if assistant_model is None:
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d4069766636053..6fff23f6b6df13 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -212,7 +212,7 @@ def _skip_init(*args, **kwargs):
         setattr(torch.nn.init, name, init_func)


-def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
     except StopIteration:
@@ -227,7 +227,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
         return first_tuple[1].device


-def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     """
     Returns the first parameter dtype (can be non-floating) or asserts if none were found.
     """
@@ -245,7 +245,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
         return first_tuple[1].dtype


-def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     """
     Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
     """
@@ -1309,6 +1309,7 @@ def floating_point_ops(
         return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


+# TODO (joao): remove `GenerationMixin` inheritance in v4.50
 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
     r"""
     Base class for all models.
@@ -1638,11 +1639,30 @@ def can_generate(cls) -> bool:
         Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
-        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
-        # Alternativelly, the model can also have a custom `generate` function.
-        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
-            return False
-        return True
+        # Directly inherits `GenerationMixin` -> can generate
+        if "GenerationMixin" in str(cls.__bases__):
+            return True
+        # Model class overwrites `generate` (e.g. time series models) -> can generate
+        if str(cls.__name__) in str(cls.generate):
+            return True
+        # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
+        # was how we detected whether a model could generate.
+        if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
+            logger.warning_once(
+                f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
+                "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
+                "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
+                "to call `generate` and other related functions."
+                "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
+                "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes"
+                "\n - If you are the owner of the model architecture code, please modify your model class such that "
+                "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)."
+                "\n - If you are not the owner of the model architecture class, please contact the model code owner "
+                "to update it."
+            )
+            return True
+        # Otherwise, can't generate
+        return False

     @classmethod
     def _check_and_enable_flash_attn_2(
diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index dca1fe7f600295..bfd8e38687accc 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -24,6 +24,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -983,7 +984,7 @@ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
     "Albert Model with a `language modeling` head on top.",
     ALBERT_START_DOCSTRING,
 )
-class AlbertForMaskedLM(AlbertPreTrainedModel):
+class AlbertForMaskedLM(AlbertPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

     def __init__(self, config):
diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py
index 220ae97f5073c6..7809b2a6cc2cfc 100644
--- a/src/transformers/models/auto/auto_factory.py
+++ b/src/transformers/models/auto/auto_factory.py
@@ -29,12 +29,17 @@
     extract_commit_hash,
     find_adapter_config_file,
     is_peft_available,
+    is_torch_available,
     logging,
     requires_backends,
 )
 from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings


+if is_torch_available():
+    from ...generation import GenerationMixin
+
+
 logger = logging.get_logger(__name__)


@@ -428,6 +433,7 @@ def from_config(cls, config, **kwargs):
             model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
             cls.register(config.__class__, model_class, exist_ok=True)
             _ = kwargs.pop("code_revision", None)
+            model_class = add_generation_mixin_to_remote_model(model_class)
             return model_class._from_config(config, **kwargs)
         elif type(config) in cls._model_mapping.keys():
             model_class = _get_model_class(config, cls._model_mapping)
@@ -549,6 +555,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             )
             _ = hub_kwargs.pop("code_revision", None)
             cls.register(config.__class__, model_class, exist_ok=True)
+            model_class = add_generation_mixin_to_remote_model(model_class)
             return model_class.from_pretrained(
                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
             )
@@ -698,6 +705,34 @@ def getattribute_from_module(module, attr):
     raise ValueError(f"Could not find {attr} in {transformers_module}!")


+def add_generation_mixin_to_remote_model(model_class):
+    """
+    Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.
+
+    This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
+    `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
+    from the Hub may not have the `generate` method after we remove the inheritance.
+    """
+    # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
+    if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
+        return model_class
+
+    # 2. If it already **directly** inherits from GenerationMixin, do nothing
+    if "GenerationMixin" in str(model_class.__bases__):
+        return model_class
+
+    # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
+    #    `prepare_inputs_for_generation` method.
+    has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate"))
+    has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation"))
+    if has_custom_generate or has_custom_prepare_inputs:
+        model_class_with_generation_mixin = type(
+            model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
+        )
+        return model_class_with_generation_mixin
+    return model_class
+
+
 class _LazyAutoMapping(OrderedDict):
     """
     " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
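The three files above carry the behavioral change; the rest of the patch is mostly the mechanical per-model update of adding `GenerationMixin` to each generative class's bases, plus tests. As a rough sketch of what the new `can_generate` checks mean for model authors (illustrative only, not part of the patch; `MyConfig`, `MyModelForCausalLM` and `MyModelBackbone` are made-up names), a custom class would opt into generation like this:

# Illustrative sketch, not part of the patch. All class names below are hypothetical.
from transformers import PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin


class MyConfig(PretrainedConfig):
    model_type = "my-model"


# `GenerationMixin` is listed after `PreTrainedModel`, as the deprecation warning
# added to `can_generate` requires.
class MyModelForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = MyConfig

    def __init__(self, config):
        super().__init__(config)
        # real layers omitted in this sketch

    def forward(self, input_ids=None, **kwargs):
        raise NotImplementedError("sketch only")


# With the patched `can_generate`, direct inheritance is the first check
# ("GenerationMixin" in str(cls.__bases__)), so no deprecation warning is emitted.
print(MyModelForCausalLM.can_generate())  # expected: True


# A head-less class that neither inherits `GenerationMixin` nor overrides
# `prepare_inputs_for_generation` reports False; calling `.generate()` on it would
# hit the TypeError introduced in `_validate_model_class` above.
class MyModelBackbone(PreTrainedModel):
    config_class = MyConfig


print(MyModelBackbone.can_generate())  # expected: False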
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 5aad7b23a8a672..3102ada542d57d 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import functional as F +from ...generation import GenerationMixin from ...generation.logits_process import ( AlternatingCodebooksLogitsProcessor, BarkEosPrioritizerLogitsProcessor, @@ -546,7 +547,7 @@ def device(self) -> torch.device: # GPT2-like autoregressive model -class BarkCausalModel(BarkPreTrainedModel): +class BarkCausalModel(BarkPreTrainedModel, GenerationMixin): config_class = BarkSubModelConfig def __init__(self, config): diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index fa928d05caa89b..2e4e6dcaeb2d11 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1557,7 +1558,7 @@ def forward( @add_start_docstrings( "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING ) -class BartForConditionalGeneration(BartPreTrainedModel): +class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] _keys_to_ignore_on_load_missing = ["final_logits_bias"] @@ -2010,7 +2011,7 @@ def forward(self, *args, **kwargs): """, BART_START_DOCSTRING, ) -class BartForCausalLM(BartPreTrainedModel): +class BartForCausalLM(BartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 93d6d469b51226..b62746da5c6f15 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -1280,7 +1281,7 @@ def forward( @add_start_docstrings( """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING ) -class BertLMHeadModel(BertPreTrainedModel): +class BertLMHeadModel(BertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] def __init__(self, config): diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index a5fb3d0531153e..8496d1f6072f02 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import 
PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer @@ -863,7 +864,7 @@ def _tie_weights(self): """BertGeneration Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_GENERATION_START_DOCSTRING, ) -class BertGenerationDecoder(BertGenerationPreTrainedModel): +class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index a6b1660d5ae1e2..41045cb5f0001f 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -2495,7 +2496,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ @add_start_docstrings( """BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING ) -class BigBirdForCausalLM(BigBirdPreTrainedModel): +class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 9f8e3cd19cd835..e26dce1edfc20f 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -2436,7 +2437,7 @@ def forward( BIGBIRD_PEGASUS_START_DOCSTRING, ) # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS -class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): +class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] _keys_to_ignore_on_load_missing = ["final_logits_bias"] @@ -2882,7 +2883,7 @@ def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) -class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel): +class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 16f7aab5c3dfa5..7ad1dcbd661c32 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from 
...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -719,7 +720,7 @@ def forward( @add_start_docstrings( """BioGPT Model with a `language modeling` head on top for CLM fine-tuning.""", BIOGPT_START_DOCSTRING ) -class BioGptForCausalLM(BioGptPreTrainedModel): +class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): _tied_weights_keys = ["output_projection.weight"] def __init__(self, config): diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 12d259fde71ec5..4ea5926d854c98 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1196,7 +1197,7 @@ def forward( @add_start_docstrings( "The Blenderbot Model with a language modeling head. Can be used for summarization.", BLENDERBOT_START_DOCSTRING ) -class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): +class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] @@ -1397,7 +1398,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill -class BlenderbotForCausalLM(BlenderbotPreTrainedModel): +class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index aa0e38bd8e9148..3e378f483a317a 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1163,7 +1164,7 @@ def forward( "The BlenderbotSmall Model with a language modeling head. 
Can be used for summarization.", BLENDERBOT_SMALL_START_DOCSTRING, ) -class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): +class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] @@ -1349,7 +1350,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M -class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel): +class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 2392961037f211..aef9b8cebec91f 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -24,6 +24,7 @@ from torch.nn.functional import normalize from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -1035,7 +1036,7 @@ def forward( """, BLIP_START_DOCSTRING, ) -class BlipForConditionalGeneration(BlipPreTrainedModel): +class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin): config_class = BlipConfig _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"] main_input_name = "pixel_values" diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index a800ba89825dcb..78384e6ce2f74b 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -808,7 +809,7 @@ def forward( # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811 -class BlipTextLMHeadModel(BlipTextPreTrainedModel): +class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 8c3b5254ea8b90..0b33572a689c2a 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -2006,7 +2007,7 @@ def forward( """, BLIP_2_START_DOCSTRING, ) -class Blip2ForConditionalGeneration(Blip2PreTrainedModel): +class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): config_class = Blip2Config main_input_name = "pixel_values" diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index b5b221b6b37f83..0992a5519f953d 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ 
b/src/transformers/models/bloom/modeling_bloom.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -860,7 +861,7 @@ def _update_causal_mask( """, BLOOM_START_DOCSTRING, ) -class BloomForCausalLM(BloomPreTrainedModel): +class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: BloomConfig): diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 0d12c800c156f7..95540f96d3b6f6 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -1544,7 +1545,7 @@ def forward( """CamemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", CAMEMBERT_START_DOCSTRING ) # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->Camembert, ROBERTA->CAMEMBERT, FacebookAI/roberta-base->almanach/camembert-base -class CamembertForCausalLM(CamembertPreTrainedModel): +class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 23334311ca9511..c631181f00c59e 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -1496,7 +1497,7 @@ def _update_causal_mask( "Chameleon Model with a head on top used for outputting logits for next token prediction.", CHAMELEON_START_DOCSTRING, ) -class ChameleonForConditionalGeneration(ChameleonPreTrainedModel): +class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index 479b0fac2b0490..f438226064ec2d 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation import GenerationConfig +from ...generation import GenerationConfig, GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1278,7 +1278,7 @@ def forward( "The CLVP decoder model with a language modelling head on top.", CLVP_START_DOCSTRING, ) -class 
ClvpForCausalLM(ClvpPreTrainedModel): +class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin): def __init__(self, config): super().__init__(config) @@ -1509,7 +1509,7 @@ def _reorder_cache( "together to filter out the best speech_ids.", CLVP_START_DOCSTRING, ) -class ClvpModelForConditionalGeneration(ClvpPreTrainedModel): +class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin): config_class = ClvpConfig def __init__(self, config: ClvpConfig): diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index be57838975c0b0..7d6f64d6461a2e 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -702,7 +703,7 @@ def _update_causal_mask( """, CODEGEN_START_DOCSTRING, ) -class CodeGenForCausalLM(CodeGenPreTrainedModel): +class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index cb1b3f885798c8..12586af23f0d7b 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -32,6 +32,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1068,7 +1069,7 @@ def _update_causal_mask( # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere -class CohereForCausalLM(CoherePreTrainedModel): +class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Ignore copy diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index c8a313505251fb..964d0bbfd1456b 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging @@ -736,7 +737,7 @@ def forward( """, CPMANT_START_DOCSTRING, ) -class CpmAntForCausalLM(CpmAntPreTrainedModel): +class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: CpmAntConfig): diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index bbf3b10a62ec1c..6d921621d47dcb 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...generation import GenerationMixin from ...modeling_outputs 
import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer @@ -503,7 +504,7 @@ def forward( """, CTRL_START_DOCSTRING, ) -class CTRLLMHeadModel(CTRLPreTrainedModel): +class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index a41fdfb56ed10a..fcddeab7a595ea 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -866,7 +867,7 @@ def forward( @add_start_docstrings( """Data2VecText Model with a `language modeling` head on top for CLM fine-tuning.""", DATA2VECTEXT_START_DOCSTRING ) -class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel): +class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 7263713c084007..46de60e24f1a04 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_utils import PreTrainedModel @@ -1227,7 +1228,7 @@ def _update_causal_mask( @add_start_docstrings("The DBRX Model transformer for causal language modeling.", DBRX_START_DOCSTRING) -class DbrxForCausalLM(DbrxPreTrainedModel): +class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin): def __init__(self, config: DbrxConfig): super().__init__(config) self.transformer = DbrxModel(config) diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index dd017170bef9a3..a200d716d451e2 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions, @@ -1524,7 +1525,7 @@ def forward( @add_start_docstrings( """ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.""", ELECTRA_START_DOCSTRING ) -class ElectraForCausalLM(ElectraPreTrainedModel): +class ElectraForCausalLM(ElectraPreTrainedModel, GenerationMixin): _tied_weights_keys = ["generator_lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 6a0a26a5cbe56b..6d81c97da02302 100644 --- 
a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1081,7 +1082,7 @@ def forward( @add_start_docstrings( """Ernie Model with a `language modeling` head on top for CLM fine-tuning.""", ERNIE_START_DOCSTRING ) -class ErnieForCausalLM(ErniePreTrainedModel): +class ErnieForCausalLM(ErniePreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 9a37fe22e1779e..270845c20aae2e 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -25,6 +25,7 @@ from ...activations import get_activation from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1239,7 +1240,7 @@ def _update_causal_mask( "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).", FALCON_START_DOCSTRING, ) -class FalconForCausalLM(FalconPreTrainedModel): +class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: FalconConfig): diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index f682f75f222ef6..011197d9854273 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import MambaCache +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -717,7 +718,7 @@ def forward( FALCONMAMBA_START_DOCSTRING, ) # Copied from transformers.models.mamba.modeling_mamba.MambaForCausalLM with MAMBA->FALCONMAMBA,Mamba->FalconMamba,mamba->falcon_mamba,FalconMambaCache->MambaCache -class FalconMambaForCausalLM(FalconMambaPreTrainedModel): +class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py index 50c6f7ede2229f..ef1501e780350d 100644 --- a/src/transformers/models/flaubert/modeling_flaubert.py +++ b/src/transformers/models/flaubert/modeling_flaubert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -644,7 +645,7 @@ def forward( FLAUBERT_START_DOCSTRING, ) # Copied transformers.models.xlm.modeling_xlm.XLMWithLMHeadModel with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert -class FlaubertWithLMHeadModel(FlaubertPreTrainedModel): +class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, 
GenerationMixin): _tied_weights_keys = ["pred_layer.proj.weight"] def __init__(self, config): diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 179408aba38ebe..4d50f9bb5925b4 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -35,6 +35,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN +from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import ( BaseModelOutput, @@ -1173,7 +1174,7 @@ def set_output_embeddings(self, value): @add_start_docstrings( "The FSMT Model with a language modeling head. Can be used for summarization.", FSMT_START_DOCSTRING ) -class FSMTForConditionalGeneration(PretrainedFSMTModel): +class FSMTForConditionalGeneration(PretrainedFSMTModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"] diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 089313b03b7b60..0aabbf6b3654b7 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -20,6 +20,7 @@ import torch.utils.checkpoint from torch import nn +from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...models.auto.modeling_auto import AutoModelForCausalLM @@ -145,7 +146,7 @@ def _init_weights(self, module): "Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.", FUYU_START_DOCSTRING, ) -class FuyuForCausalLM(FuyuPreTrainedModel): +class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): def __init__(self, config: FuyuConfig): super().__init__(config) self.padding_idx = config.pad_token_id diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py index 36f2d1c594abaa..dcc43bc74aece9 100644 --- a/src/transformers/models/gemma/diff_gemma.py +++ b/src/transformers/models/gemma/diff_gemma.py @@ -34,6 +34,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import CausalLMOutputWithPast from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -527,7 +528,7 @@ def forward( # Example where we ony modify the docstring and call super -class GemmaForCausalLM(LlamaForCausalLM): +class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin): def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index dd4c899d13d465..8d9bb88686de24 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -988,7 +989,7 @@ def _update_causal_mask( return causal_mask -class GemmaForCausalLM(GemmaPreTrainedModel): +class GemmaForCausalLM(GemmaPreTrainedModel, 
GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/diff_gemma2.py index 30f371a1b61267..a66ce3160b5fd1 100644 --- a/src/transformers/models/gemma2/diff_gemma2.py +++ b/src/transformers/models/gemma2/diff_gemma2.py @@ -33,6 +33,7 @@ ) from ...cache_utils import Cache +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging @@ -473,7 +474,7 @@ def _update_causal_mask( return causal_mask -class Gemma2ForCausalLM(GemmaForCausalLM): +class Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin): def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index be964c9aed018a..6b55500739b40b 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -931,7 +932,7 @@ def _update_causal_mask( return causal_mask -class Gemma2ForCausalLM(Gemma2PreTrainedModel): +class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 7471eec811a065..59d3a406ec3535 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...file_utils import ModelOutput +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1324,7 +1325,7 @@ def forward( @add_start_docstrings( """GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING ) -class GitForCausalLM(GitPreTrainedModel): +class GitForCausalLM(GitPreTrainedModel, GenerationMixin): _tied_weights_keys = ["output.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 8dfbfb9064444d..e99f4b126246d8 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -1182,7 +1183,7 @@ def forward( """, GPT2_START_DOCSTRING, ) -class GPT2LMHeadModel(GPT2PreTrainedModel): +class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): @@ -1384,7 +1385,7 @@ def _reorder_cache( """, GPT2_START_DOCSTRING, ) -class GPT2DoubleHeadsModel(GPT2PreTrainedModel): +class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def 
__init__(self, config): diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 0f927a72469dc9..ca1c03fcd9f911 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -22,6 +22,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -1040,7 +1041,7 @@ def forward( """, GPT_BIGCODE_START_DOCSTRING, ) -class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): +class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 28309f7738ebe2..2fae1753154ccf 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -917,7 +918,7 @@ def _update_causal_mask( """, GPT_NEO_START_DOCSTRING, ) -class GPTNeoForCausalLM(GPTNeoPreTrainedModel): +class GPTNeoForCausalLM(GPTNeoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 274c571fa8939c..c1b2aa899985c8 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -30,6 +30,7 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, ) +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1110,7 +1111,7 @@ def _update_causal_mask( @add_start_docstrings( """GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING ) -class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel): +class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 048e108a8ec2d7..3db2099511bc6b 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS @@ -815,7 +816,7 @@ def _update_causal_mask( """GPTNeoXJapanese 
Model with a `language modeling` head on top for Classifier Model fine-tuning.""", GPT_NEOX_JAPANESE_START_DOCSTRING, ) -class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel): +class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel, GenerationMixin): _tied_weights_keys = ["embed_out.weight"] def __init__(self, config): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 84f6d985f76474..9eeb26c5e403e0 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1011,7 +1012,7 @@ def _update_causal_mask( """, GPTJ_START_DOCSTRING, ) -class GPTJForCausalLM(GPTJPreTrainedModel): +class GPTJForCausalLM(GPTJPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index f62de411a4fa5f..9a8d4570e7befe 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -22,6 +22,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -1004,7 +1005,7 @@ def _update_causal_mask( return causal_mask -class GraniteForCausalLM(GranitePreTrainedModel): +class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Granite diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 3ac462bdad34db..d724485990b938 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -1234,7 +1235,7 @@ def _update_causal_mask( return causal_mask -class GraniteMoeForCausalLM(GraniteMoePreTrainedModel): +class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: GraniteMoeConfig): diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 41be300095e710..9273d91ac401ff 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -23,11 +23,12 @@ from torch import nn from torch.nn import CrossEntropyLoss -from ... 
import PreTrainedModel from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -1450,7 +1451,7 @@ def forward( """The Idefics2 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """, IDEFICS2_START_DOCSTRING, ) -class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel): +class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 5d59a4ed90e4c9..a027876b43d369 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -880,7 +881,7 @@ def forward( """, IMAGEGPT_START_DOCSTRING, ) -class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): +class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: ImageGPTConfig): diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index ba77afe9f7c211..dff897f59d2d26 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1283,7 +1284,7 @@ def forward( """, INSTRUCTBLIP_START_DOCSTRING, ) -class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel): +class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin): config_class = InstructBlipConfig main_input_name = "pixel_values" diff --git a/src/transformers/models/instructblipvideo/diff_instructblipvideo.py b/src/transformers/models/instructblipvideo/diff_instructblipvideo.py index 506da83c532240..be569abc9137c2 100644 --- a/src/transformers/models/instructblipvideo/diff_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/diff_instructblipvideo.py @@ -45,6 +45,7 @@ InstructBlipVisionModel, ) +from ...generation import GenerationMixin from ...utils import logging @@ -128,7 +129,7 @@ class InstructBlipVideoQFormerModel(InstructBlipQFormerModel): pass -class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration): +class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration, GenerationMixin): def forward( self, pixel_values: torch.FloatTensor, diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 8cb813e0ac571c..bcc299b1ba7831 100644 --- 
a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -30,6 +30,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1292,7 +1293,7 @@ def forward( """, INSTRUCTBLIPVIDEO_START_DOCSTRING, ) -class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel): +class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin): config_class = InstructBlipVideoConfig main_input_name = "pixel_values" diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 60e1670a3c2784..4b8630efbfa946 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1424,7 +1425,7 @@ def _update_mamba_mask(self, attention_mask, cache_position): # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba -class JambaForCausalLM(JambaPreTrainedModel): +class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: JambaConfig): diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 7c4394d0e1a168..e9c06960499136 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1202,7 +1203,7 @@ def _update_causal_mask( return causal_mask -class JetMoeForCausalLM(JetMoePreTrainedModel): +class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 69641790b2db8f..90e21ed2f5582b 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1521,7 +1522,7 @@ def forward( """, KOSMOS2_START_DOCSTRING, ) -class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel): +class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): config_class = Kosmos2TextConfig _tied_weights_keys = ["lm_head.weight"] @@ -1864,7 +1865,7 @@ def forward( """, KOSMOS2_START_DOCSTRING, ) -class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel): +class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): config_class = Kosmos2Config main_input_name = "pixel_values" _tied_weights_keys = ["text_model.lm_head.weight"] diff --git 
a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 41b6c0a2bea27d..f96bfd82b52638 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -2298,7 +2299,7 @@ def forward( @add_start_docstrings( "The LED Model with a language modeling head. Can be used for summarization.", LED_START_DOCSTRING ) -class LEDForConditionalGeneration(LEDPreTrainedModel): +class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin): base_model_prefix = "led" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0bc44f314b5e86..73b6bcd8b4a4d7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -1101,7 +1102,7 @@ def _update_causal_mask( return causal_mask -class LlamaForCausalLM(LlamaPreTrainedModel): +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index eb1c55341b0784..092008873d1e27 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -21,9 +21,10 @@ import torch.utils.checkpoint from torch import nn -from ... import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -237,7 +238,7 @@ def _supports_sdpa(self): """The LLAVA model which consists of a vision backbone and a language model.""", LLAVA_START_DOCSTRING, ) -class LlavaForConditionalGeneration(LlavaPreTrainedModel): +class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): def __init__(self, config: LlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index bf76921090b244..a96b0d89420437 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -23,10 +23,11 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -349,7 +350,7 @@ def _supports_sdpa(self): """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", LLAVA_NEXT_START_DOCSTRING, ) -class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel): +class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin): def __init__(self, config: LlavaNextConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/llava_next_video/diff_llava_next_video.py b/src/transformers/models/llava_next_video/diff_llava_next_video.py index e765dfb95cc335..c5ca2bf00324d4 100644 --- a/src/transformers/models/llava_next_video/diff_llava_next_video.py +++ b/src/transformers/models/llava_next_video/diff_llava_next_video.py @@ -29,6 +29,7 @@ image_size_to_num_patches, ) +from ...generation import GenerationMixin from ...utils import ( logging, replace_return_docstrings, @@ -218,7 +219,7 @@ class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector): pass -class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): +class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration, GenerationMixin): def __init__(self, config: LlavaNextVideoConfig, **super_kwargs): super().__init__(config, **super_kwargs) self.vision_resampler = LlavaNextVideoPooler(config) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 589bf346ceeb9e..7ad9e0769eb35e 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -29,10 +29,11 @@ import torch.utils.checkpoint from torch import nn -from ... import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -387,7 +388,7 @@ def _supports_sdpa(self): """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", LLAVA_NEXT_VIDEO_START_DOCSTRING, ) -class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel): +class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin): def __init__( self, config: LlavaNextVideoConfig, diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 593500c2e404e1..948efbc922b70d 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -23,10 +23,11 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, logging, @@ -358,7 +359,7 @@ def _init_weights(self, module): """The LLaVA-Onevision model which consists of a vision backbone and a language model.""", LLAVA_ONEVISION_START_DOCSTRING, ) -class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel): +class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin): def __init__(self, config: LlavaOnevisionConfig): super().__init__(config) self.vision_tower = AutoModel.from_config( diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index b2a6ed11ca5728..8f9385c0fe76ed 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1900,7 +1901,7 @@ def forward( @add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING) -class LongT5ForConditionalGeneration(LongT5PreTrainedModel): +class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_unexpected = [ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 23a855fff25672..86a4378da29cdb 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1342,7 +1343,7 @@ def forward( @add_start_docstrings( "The M2M100 Model with a language modeling head. 
Can be used for summarization.", M2M_100_START_DOCSTRING ) -class M2M100ForConditionalGeneration(M2M100PreTrainedModel): +class M2M100ForConditionalGeneration(M2M100PreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 14a3dea1d1ccf8..6bed1caab23ab7 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import MambaCache +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -657,7 +658,7 @@ def forward( """, MAMBA_START_DOCSTRING, ) -class MambaForCausalLM(MambaPreTrainedModel): +class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 19d53437130e24..01074af38a510b 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -932,7 +933,7 @@ def forward( """, MAMBA2_START_DOCSTRING, ) -class Mamba2ForCausalLM(Mamba2PreTrainedModel): +class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin): _tied_weights_keys = [] def __init__(self, config): diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 2045f673540f52..cb26bb11e094cd 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1224,7 +1225,7 @@ def forward( @add_start_docstrings( "The Marian Model with a language modeling head. 
Can be used for summarization.", MARIAN_START_DOCSTRING ) -class MarianMTModel(MarianPreTrainedModel): +class MarianMTModel(MarianPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = [ "final_logits_bias", @@ -1504,7 +1505,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en -class MarianForCausalLM(MarianPreTrainedModel): +class MarianForCausalLM(MarianPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 9455f21b2073ff..3f2d6cb8e2ba8d 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1526,7 +1527,7 @@ def forward( "The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.", MBART_START_DOCSTRING, ) -class MBartForConditionalGeneration(MBartPreTrainedModel): +class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"] @@ -1967,7 +1968,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25 -class MBartForCausalLM(MBartPreTrainedModel): +class MBartForCausalLM(MBartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 16641655e20395..20506f91bcbcb2 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1110,7 +1111,7 @@ def forward( """MegatronBert Model with a `language modeling` head on top for CLM fine-tuning.""", MEGATRON_BERT_START_DOCSTRING, ) -class MegatronBertForCausalLM(MegatronBertPreTrainedModel): +class MegatronBertForCausalLM(MegatronBertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder"] def __init__(self, config): diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 80992734046ad6..ffa1a18307e982 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from 
...modeling_outputs import ( BaseModelOutputWithPast, @@ -950,7 +951,7 @@ def _update_causal_mask( return causal_mask -class MistralForCausalLM(MistralPreTrainedModel): +class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index fcc0d66e19c4a4..a1786fbb17e3c5 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1186,7 +1187,7 @@ def _update_causal_mask( return causal_mask -class MixtralForCausalLM(MixtralPreTrainedModel): +class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 85579636dcc4cd..9c826c370b752a 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -24,6 +24,7 @@ from torch.nn import functional as F from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -500,7 +501,7 @@ def forward( """, MPT_START_DOCSTRING, ) -class MptForCausalLM(MptPreTrainedModel): +class MptForCausalLM(MptPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: MptConfig): diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 54943cf982dd24..6a7406f11b5b56 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1550,7 +1551,7 @@ def forward( @add_start_docstrings("""MT5 Model with a `language modeling` head on top.""", MT5_START_DOCSTRING) -class MT5ForConditionalGeneration(MT5PreTrainedModel): +class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin): r""" Examples: diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index f720faac038e51..3109c4fc243118 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -26,9 +26,14 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation.configuration_utils import GenerationConfig, GenerationMode -from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList -from ...generation.stopping_criteria import StoppingCriteriaList +from ...generation import ( + ClassifierFreeGuidanceLogitsProcessor, + GenerationConfig, + GenerationMixin, + GenerationMode, + 
LogitsProcessorList, + StoppingCriteriaList, +) from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1206,7 +1211,7 @@ def forward( "The MusicGen decoder model with a language modelling head on top.", MUSICGEN_START_DOCSTRING, ) -class MusicgenForCausalLM(MusicgenPreTrainedModel): +class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin): def __init__(self, config: MusicgenDecoderConfig): super().__init__(config) @@ -1658,7 +1663,7 @@ def generate( "for music generation tasks with one or both of text and audio prompts.", MUSICGEN_START_DOCSTRING, ) -class MusicgenForConditionalGeneration(PreTrainedModel): +class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): config_class = MusicgenConfig base_model_prefix = "encoder_decoder" main_input_name = "input_ids" diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index a8a8fe96098952..c8345870b2537e 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -26,9 +26,14 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation.configuration_utils import GenerationConfig, GenerationMode -from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList -from ...generation.stopping_criteria import StoppingCriteriaList +from ...generation import ( + ClassifierFreeGuidanceLogitsProcessor, + GenerationConfig, + GenerationMixin, + GenerationMode, + LogitsProcessorList, + StoppingCriteriaList, +) from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1117,7 +1122,7 @@ def forward( MUSICGEN_MELODY_START_DOCSTRING, ) # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody,MusicGen->Musicgen Melody -class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): +class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin): def __init__(self, config: MusicgenMelodyDecoderConfig): super().__init__(config) @@ -1585,7 +1590,7 @@ def generate( decoder (`Optional[MusicgenMelodyForCausalLM]`, *optional*): MusicGen Melody decoder used to generate audio codes. """, ) -class MusicgenMelodyForConditionalGeneration(PreTrainedModel): +class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): config_class = MusicgenMelodyConfig main_input_name = "input_ids" supports_gradient_checkpointing = True diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 319f1760cef9df..c47c4b26b539f7 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1351,7 +1352,7 @@ def forward( @add_start_docstrings( "The MVP Model with a language modeling head. 
Can be used for various text generation tasks.", MVP_START_DOCSTRING ) -class MvpForConditionalGeneration(MvpPreTrainedModel): +class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] def __init__(self, config: MvpConfig): @@ -1791,7 +1792,7 @@ def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) -class MvpForCausalLM(MvpPreTrainedModel): +class MvpForCausalLM(MvpPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 4d079b4dde104d..aa699853d55762 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -980,7 +981,7 @@ def _update_causal_mask( # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron -class NemotronForCausalLM(NemotronPreTrainedModel): +class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 2bec0fb84dce56..c33844da0f55b8 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1604,7 +1605,7 @@ def forward( @add_start_docstrings( "The NllbMoe Model with a language modeling head. 
Can be used for summarization.", NLLB_MOE_START_DOCSTRING ) -class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): +class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 03fa524532a02e..a44b7d2a0a4c4d 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1022,7 +1023,7 @@ def _update_causal_mask( # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo -class OlmoForCausalLM(OlmoPreTrainedModel): +class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 2cbde7dc863169..d30cace3a7055d 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -22,6 +22,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1173,7 +1174,7 @@ def _update_causal_mask( return causal_mask -class OlmoeForCausalLM(OlmoePreTrainedModel): +class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 2b24850f3f0c89..0aa02a6f5d8424 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu_new, silu +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel, SequenceSummary from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer @@ -524,7 +525,7 @@ def forward( """, OPENAI_GPT_START_DOCSTRING, ) -class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): +class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8f058171778efc..f7782b8f6172b9 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -22,6 +22,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -882,7 +883,7 @@ def forward( ) -class OPTForCausalLM(OPTPreTrainedModel): +class 
OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 48fffb6b428df7..b5fddce1d6a914 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -22,6 +22,7 @@ from torch import nn from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -302,7 +303,7 @@ def _supports_sdpa(self): """The PALIGEMMA model which consists of a vision backbone and a language model.""", PALIGEMMA_START_DOCSTRING, ) -class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel): +class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): def __init__(self, config: PaliGemmaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config=config.vision_config) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 42cef3a63558e2..03d1574e9be28e 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1244,7 +1245,7 @@ def forward( @add_start_docstrings( "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING ) -class PegasusForConditionalGeneration(PegasusPreTrainedModel): +class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] @@ -1456,7 +1457,7 @@ def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) -class PegasusForCausalLM(PegasusPreTrainedModel): +class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 6d9072777bf634..77c0b32e6433c4 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1464,7 +1465,7 @@ def forward( @add_start_docstrings("The PEGASUS-X for conditional generation (e.g. 
summarization).", PEGASUS_X_START_DOCSTRING) -class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): +class PegasusXForConditionalGeneration(PegasusXPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index a6fd2284afb691..4f122e14284d83 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -847,7 +848,7 @@ def _update_causal_mask( return causal_mask -class PersimmonForCausalLM(PersimmonPreTrainedModel): +class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 4d0a076b5f9a33..03ed19bc34ac84 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1139,7 +1140,7 @@ def _update_causal_mask( return causal_mask -class PhiForCausalLM(PhiPreTrainedModel): +class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index e0ca84be184843..12ee9f017f8149 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1160,7 +1161,7 @@ def _update_causal_mask( return causal_mask -class Phi3ForCausalLM(Phi3PreTrainedModel): +class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 94d882c80566ad..f209d7d8828785 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -22,6 +22,7 @@ from torch import nn from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -1553,7 +1554,7 @@ def forward( "A conditional generation model with a language modeling head. 
Can be used for sequence generation tasks.", PIX2STRUCT_START_DOCSTRING, ) -class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): +class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin): config_class = Pix2StructConfig main_input_name = "flattened_patches" _tied_weights_keys = ["decoder.lm_head.weight"] diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 93d91e160089e1..d15e079770a3ed 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1254,7 +1255,7 @@ def forward( "The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.", PLBART_START_DOCSTRING, ) -class PLBartForConditionalGeneration(PLBartPreTrainedModel): +class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] @@ -1568,7 +1569,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base -class PLBartForCausalLM(PLBartPreTrainedModel): +class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index c769cff3c454ec..e6488898e8a938 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -25,6 +25,7 @@ from transformers.generation import GenerationConfig from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1001,7 +1002,7 @@ def forward(self, feature, index_value, embedding_offset): @add_start_docstrings("""Pop2Piano Model with a `language modeling` head on top.""", Pop2Piano_START_DOCSTRING) -class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel): +class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] def __init__(self, config: Pop2PianoConfig): diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 96fa2e2c12e52f..7d23088f6e575e 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -26,6 +26,7 @@ from torch.nn import LayerNorm from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -1856,7 +1857,7 @@ def forward( "The ProphetNet Model with a language modeling head. 
Can be used for sequence generation tasks.", PROPHETNET_START_DOCSTRING, ) -class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): +class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"] def __init__(self, config: ProphetNetConfig): @@ -2073,7 +2074,7 @@ def get_decoder(self): " language modeling.", PROPHETNET_START_DOCSTRING, ) -class ProphetNetForCausalLM(ProphetNetPreTrainedModel): +class ProphetNetForCausalLM(ProphetNetPreTrainedModel, GenerationMixin): _tied_weights_keys = [ "prophetnet.word_embeddings.weight", "prophetnet.decoder.word_embeddings.weight", diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 1e79115d34701b..10c0b6f38669da 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1078,7 +1079,7 @@ def _update_causal_mask( return causal_mask -class Qwen2ForCausalLM(Qwen2PreTrainedModel): +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 14235bf0aaf64a..bf48e1c6a97ef9 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -22,10 +22,11 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN from ...cache_utils import Cache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -855,7 +856,7 @@ def forward(self, audio_features): """The QWEN2AUDIO model which consists of a audio backbone and a language model.""", QWEN2AUDIO_START_DOCSTRING, ) -class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel): +class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): def __init__(self, config: Qwen2AudioConfig): super().__init__(config) self.audio_tower = AutoModel.from_config(config.audio_config, attn_implementation=config._attn_implementation) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index c9ee7b5f57a1f2..1b28e9baf25f66 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1253,7 +1254,7 @@ def _update_causal_mask( return causal_mask -class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel): +class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 02716153fddaee..938ec4d5e42362 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -31,6 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1416,7 +1417,7 @@ def _update_causal_mask( """ -class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel): +class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index a8f076fad79c76..e0492948998434 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput from ...modeling_utils import PreTrainedModel @@ -777,7 +778,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma -class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel): +class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def 
__init__(self, config): diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 2e98a07217e682..37b675539e66cf 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward @@ -2183,7 +2184,7 @@ def _pad_to_mult_of_chunk_length( @add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING) -class ReformerModelWithLMHead(ReformerPreTrainedModel): +class ReformerModelWithLMHead(ReformerPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 31f7e3dce4548c..99016c1be429cd 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1002,7 +1003,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ @add_start_docstrings( """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING ) -class RemBertForCausalLM(RemBertPreTrainedModel): +class RemBertForCausalLM(RemBertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.weight"] def __init__(self, config): diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index bbf16ec039b4cd..91500e1926d753 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -1003,7 +1004,7 @@ def forward( @add_start_docstrings( """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.""", ROBERTA_START_DOCSTRING ) -class RobertaForCausalLM(RobertaPreTrainedModel): +class RobertaForCausalLM(RobertaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 95657c260dc7a4..9ed9b11d9431df 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from 
...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -855,7 +856,7 @@ def forward( ROBERTA_PRELAYERNORM_START_DOCSTRING, ) # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with FacebookAI/roberta-base->andreasmadsen/efficient_mlm_m0.40,ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm, RobertaPreLayerNormTokenizer->RobertaTokenizer -class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel): +class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index c4efbf16323eba..2969f7f1a3d0a8 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1403,7 +1404,7 @@ def prepare_inputs_for_generation( @add_start_docstrings( """RoCBert Model with a `language modeling` head on top for CLM fine-tuning.""", ROC_BERT_START_DOCSTRING ) -class RoCBertForCausalLM(RoCBertPreTrainedModel): +class RoCBertForCausalLM(RoCBertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 69588ff743a0e6..c98b525abe081b 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -1033,7 +1034,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ @add_start_docstrings( """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING ) -class RoFormerForCausalLM(RoFormerPreTrainedModel): +class RoFormerForCausalLM(RoFormerPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] def __init__(self, config): diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 7dec1f26e1a390..8361afbf727b64 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -751,7 +752,7 @@ def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id): """, RWKV_START_DOCSTRING, ) -class RwkvForCausalLM(RwkvPreTrainedModel): +class 
RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin): _tied_weights_keys = ["head.weight"] def __init__(self, config): diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index ba8230ec509df8..8e226d92a10580 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -2150,7 +2151,7 @@ def forward( embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. """, ) -class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel): +class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = [ "vocoder", "speech_encoder", @@ -2664,7 +2665,7 @@ def remove_weight_norm(self): "The text-to-text SeamlessM4T Model transformer which can be used for T2TT.", SEAMLESS_M4T_START_DOCSTRING, ) -class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel): +class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] main_input_name = "input_ids" diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 2d1fde8eed69a0..aa710ad95266ff 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -2439,7 +2440,7 @@ def forward( embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. 
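# [Editor's note -- illustrative sketch, not part of the patch.] Every model hunk in this
# diff makes the same one-line change: a class that exposes `.generate()` now lists
# `GenerationMixin` as an explicit base instead of getting it transitively through
# `PreTrainedModel`. A hypothetical downstream model (the `My*` names below are invented
# purely for illustration) would migrate the same way:
from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel


class MyConfig(PretrainedConfig):
    model_type = "my-causal-lm"


# Before: `class MyModelForCausalLM(PreTrainedModel)` relied on `PreTrainedModel`
# still inheriting `GenerationMixin`, which this PR deprecates.
# After: generation support is opted into explicitly, matching the hunks above and below.
class MyModelForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = MyConfig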
""", ) -class SeamlessM4Tv2TextToUnitForConditionalGeneration(SeamlessM4Tv2PreTrainedModel): +class SeamlessM4Tv2TextToUnitForConditionalGeneration(SeamlessM4Tv2PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = [ "vocoder", "speech_encoder", @@ -2922,7 +2923,7 @@ def remove_weight_norm(self): SEAMLESS_M4T_V2_START_DOCSTRING, ) # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToText with SeamlessM4T->SeamlessM4Tv2,SeamlessM4Tv2Tokenizer->SeamlessM4TTokenizer, SeamlessM4Tv2Processor->SeamlessM4TProcessor -class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel): +class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] main_input_name = "input_ids" diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 8353a172b2120f..bdd532fa25e82a 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1207,7 +1208,7 @@ def forward( "The Speech2Text Model with a language modeling head. Can be used for summarization.", SPEECH_TO_TEXT_START_DOCSTRING, ) -class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): +class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 13641ecb37f255..463a30fabe77c0 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1123,7 +1124,7 @@ def _update_causal_mask( # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm -class StableLmForCausalLM(StableLmPreTrainedModel): +class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->STABLELM,Llama->StableLm diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 5eaf50f090fa49..079ad1298fb999 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1053,7 +1054,7 @@ def _update_causal_mask( # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM with 
QWEN2->STARCODER2,Qwen2->Starcoder2 -class Starcoder2ForCausalLM(Starcoder2PreTrainedModel): +class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c5797d4573b781..96b6c7334b1535 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, @@ -1456,7 +1457,7 @@ def forward( @add_start_docstrings( """SWITCH_TRANSFORMERS Model with a `language modeling` head on top.""", SWITCH_TRANSFORMERS_START_DOCSTRING ) -class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel): +class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] def __init__(self, config: SwitchTransformersConfig): diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index a90101924c5ba6..43e3f3afa4a837 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1542,7 +1543,7 @@ def forward( @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) -class T5ForConditionalGeneration(T5PreTrainedModel): +class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 04eb40ab2a2f47..67b97cf9c85223 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel @@ -736,7 +737,7 @@ def forward(self, *args, **kwargs): " [`VisionEncoderDecoder`].", TROCR_START_DOCSTRING, ) -class TrOCRForCausalLM(TrOCRPreTrainedModel): +class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin): _tied_weights_keys = ["output_projection.weight"] def __init__(self, config): diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 6f7b6cf060495a..c621b742323db2 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -34,6 +34,7 @@ ) from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_utils import 
PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -1679,7 +1680,7 @@ def forward( This class is based on [`T5ForConditionalGeneration`], extended to deal with images and layout (2D) data.""", UDOP_START_DOCSTRING, ) -class UdopForConditionalGeneration(UdopPreTrainedModel): +class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 3271689540b931..a7d1e5bacc65e5 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1101,7 +1102,7 @@ def forward( @add_start_docstrings("""UMT5 Model with a `language modeling` head on top.""", UMT5_START_DOCSTRING) -class UMT5ForConditionalGeneration(UMT5PreTrainedModel): +class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin): r""" Examples: diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 9ae80be65ae4b6..7c7cfec20959ea 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -21,9 +21,10 @@ import torch.utils.checkpoint from torch import nn -from ... import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -239,7 +240,7 @@ def _supports_sdpa(self): """The VideoLlava model which consists of a vision backbone and a language model.""", VIDEO_LLAVA_START_DOCSTRING, ) -class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel): +class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin): def __init__(self, config: VideoLlavaConfig): super().__init__(config) self.video_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 53a3213697193e..95129d46bbd8e5 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -21,9 +21,10 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -240,7 +241,7 @@ def _supports_sdpa(self): VIPLLAVA_START_DOCSTRING, ) # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration with LLAVA->VIPLLAVA,Llava->VipLlava -class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): +class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin): def __init__(self, config: VipLlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 8012c3c1bbfcfb..7a4e9487288e93 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -25,7 +25,7 @@ from transformers.cache_utils import EncoderDecoderCache -from ...generation.configuration_utils import GenerationConfig +from ...generation import GenerationConfig, GenerationMixin from ...generation.logits_process import ( LogitsProcessorList, SuppressTokensAtBeginLogitsProcessor, @@ -172,7 +172,7 @@ def _pad_to_max_length( return sequences -class WhisperGenerationMixin: +class WhisperGenerationMixin(GenerationMixin): def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None): """ Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index b82b978e5e6d95..93ec57fcf4b400 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, @@ -1915,7 +1916,7 @@ def forward(self, *args, **kwargs): """, WHISPER_START_DOCSTRING, ) -class WhisperForCausalLM(WhisperPreTrainedModel): +class WhisperForCausalLM(WhisperPreTrainedModel, GenerationMixin): _tied_weights_keys = ["proj_out.weight"] main_input_name = "input_ids" diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 4f1693583494ec..3090bc2973cd56 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel @@ -696,7 +697,7 @@ def forward( """, XGLM_START_DOCSTRING, ) -class XGLMForCausalLM(XGLMPreTrainedModel): +class XGLMForCausalLM(XGLMPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 
2803836309872c..3acec2353b69fc 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -657,7 +658,7 @@ def forward(self, x, y=None): """, XLM_START_DOCSTRING, ) -class XLMWithLMHeadModel(XLMPreTrainedModel): +class XLMWithLMHeadModel(XLMPreTrainedModel, GenerationMixin): _tied_weights_keys = ["pred_layer.proj.weight"] def __init__(self, config): diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 2adae33fbd50a8..a153f09468937b 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -1006,7 +1007,7 @@ def forward( XLM_ROBERTA_START_DOCSTRING, ) # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA -class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel): +class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index f86abf823e9026..0c384ad45c523b 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -986,7 +987,7 @@ def forward( """XLM-RoBERTa-XL Model with a `language modeling` head on top for CLM fine-tuning.""", XLM_ROBERTA_XL_START_DOCSTRING, ) -class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel): +class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 5d424ebe12dd6f..7681fbafad6db5 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( @@ -1286,7 +1287,7 @@ def forward( """, XLNET_START_DOCSTRING, ) -class XLNetLMHeadModel(XLNetPreTrainedModel): +class XLNetLMHeadModel(XLNetPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_loss.weight"] def __init__(self, config): diff --git 
a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index b1ca8116a72a78..71474cc9c45ba9 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -956,7 +957,7 @@ def forward( "X-MOD Model with a `language modeling` head on top for CLM fine-tuning.", XMOD_START_DOCSTRING, ) -class XmodForCausalLM(XmodPreTrainedModel): +class XmodForCausalLM(XmodPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.__init__ with Roberta->Xmod diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 2f8e60c79151e9..600942a7ac0822 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2099,6 +2099,15 @@ def test_assisted_decoding_with_num_logits_to_keep(self): ) self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) + @pytest.mark.generate + def test_inherits_generation_mixin(self): + """ + Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel` + to inherit it. + """ + for model_class in self.all_generative_model_classes: + self.assertTrue("GenerationMixin" in str(model_class.__bases__)) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape config = config.text_config if hasattr(config, "text_config") else config diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 39770b091befe8..95d716898343d0 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -67,6 +67,7 @@ BertModel, FunnelBaseModel, FunnelModel, + GenerationMixin, GPT2Config, GPT2LMHeadModel, ResNetBackbone, @@ -571,3 +572,20 @@ def test_dynamic_saving_from_local_repo(self): _ = AutoModelForCausalLM.from_pretrained(tmp_dir_out, trust_remote_code=True) self.assertTrue((Path(tmp_dir_out) / "modeling_fake_custom.py").is_file()) self.assertTrue((Path(tmp_dir_out) / "configuration_fake_custom.py").is_file()) + + def test_custom_model_patched_generation_inheritance(self): + """ + Tests that our inheritance patching for generate-compatible models works as expected. Without this feature, + old Hub models lose the ability to call `generate`. + """ + model = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/test_dynamic_model_generation", trust_remote_code=True + ) + self.assertTrue(model.__class__.__name__ == "NewModelForCausalLM") + + # It inherits from GenerationMixin. This means it can `generate`. Because `PreTrainedModel` is scheduled to + # stop inheriting from `GenerationMixin` in v4.50, this check will fail if patching is not present. + self.assertTrue(isinstance(model, GenerationMixin)) + # More precisely, it directly inherits from GenerationMixin. 
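# [Editor's note -- illustrative sketch, not part of the patch.] The test above covers the
# inheritance patching that the auto classes apply to older remote-code models. A plain
# user-side sanity check (the checkpoint name is chosen only as an example) could verify
# the same guarantees; per the comments above, `PreTrainedModel` is scheduled to stop
# inheriting from `GenerationMixin` in v4.50, so checking explicitly is the safer habit:
from transformers import AutoModelForCausalLM, GenerationMixin

model = AutoModelForCausalLM.from_pretrained("gpt2")
print(isinstance(model, GenerationMixin))  # True -> `model.generate(...)` is available
print(model.can_generate())                # True -> same information via the class method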
This check would fail prior to v4.45 (inheritance + # patching was added in v4.45) + self.assertTrue("GenerationMixin" in str(model.__class__.__bases__)) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 2130ed4b7c887f..5155647059f154 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -90,6 +90,7 @@ BertConfig, BertModel, CLIPTextModel, + GenerationMixin, PreTrainedModel, T5Config, T5ForConditionalGeneration, @@ -1715,6 +1716,32 @@ def test_isin_mps_friendly(self): torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor)) ) + def test_can_generate(self): + """Tests the behavior of `PreTrainedModel.can_generate` method.""" + # 1 - By default, a model CAN'T generate + self.assertFalse(BertModel.can_generate()) + + # 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly + class DummyBertWithMixin(BertModel, GenerationMixin): + pass + + self.assertTrue(DummyBertWithMixin.can_generate()) + + # 3 - Alternatively, a model can implement a `generate` method + class DummyBertWithGenerate(BertModel): + def generate(self): + pass + + self.assertTrue(DummyBertWithGenerate.can_generate()) + + # 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited + # `GenerationMixin`) + class DummyBertWithPrepareInputs(BertModel): + def prepare_inputs_for_generation(self): + pass + + self.assertTrue(DummyBertWithPrepareInputs.can_generate()) + def test_save_and_load_config_with_custom_generation(self): """ Regression test for the ability to save and load a config with a custom generation kwarg (i.e. a parameter
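# [Editor's note -- illustrative sketch, not part of the patch.] Outside the unittest
# harness, the `can_generate()` rules exercised by `test_can_generate` above reduce to:
from transformers import BertModel, GenerationMixin

print(BertModel.can_generate())  # False: no generation support declared


class BertWithMixin(BertModel, GenerationMixin):  # 1) direct `GenerationMixin` base
    pass


class BertWithGenerate(BertModel):  # 2) class implements its own `generate`
    def generate(self):
        pass


class BertWithPrepareInputs(BertModel):  # 3) BC: custom `prepare_inputs_for_generation`
    def prepare_inputs_for_generation(self):
        pass


print(BertWithMixin.can_generate())          # True
print(BertWithGenerate.can_generate())       # True
print(BertWithPrepareInputs.can_generate())  # True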