Generation: deprecate PreTrainedModel inheriting from `GenerationMixin` (#33203)
gante authored Sep 23, 2024
1 parent 1456120 commit e15687f
Showing 126 changed files with 407 additions and 184 deletions.
34 changes: 11 additions & 23 deletions src/transformers/generation/utils.py
@@ -34,13 +34,6 @@
 )
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
 from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
-from ..models.auto import (
-    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
-    MODEL_FOR_CAUSAL_LM_MAPPING,
-    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
-    MODEL_FOR_VISION_2_SEQ_MAPPING,
-)
 from ..pytorch_utils import isin_mps_friendly
 from ..tokenization_utils import ExtensionsTrie
 from ..utils import (
@@ -1117,26 +1110,21 @@ def _validate_model_class(self):
         Confirms that the model class is compatible with generation. If not, raises an exception that points to the
         right class to use.
         """
+        # TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from
+        # `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can
+        # safely call `GenerationMixin.generate`
         if not is_torchdynamo_compiling() and not self.can_generate():
-            generate_compatible_mappings = [
-                MODEL_FOR_CAUSAL_LM_MAPPING,
-                MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
-                MODEL_FOR_VISION_2_SEQ_MAPPING,
-                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-                MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+            terminations_with_generation_support = [
+                "ForCausalLM",
+                "ForConditionalGeneration",
+                "ForSpeechSeq2Seq",
+                "ForVision2Seq",
             ]
-            generate_compatible_classes = set()
-            for model_mapping in generate_compatible_mappings:
-                supported_models = model_mapping.get(type(self.config), default=None)
-                if supported_models is not None:
-                    generate_compatible_classes.add(supported_models.__name__)
-            exception_message = (
+            raise TypeError(
                 f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
-                "it doesn't have a language model head."
+                "it doesn't have a language model head. Classes that support generation often end in one of these "
+                f"names: {terminations_with_generation_support}."
             )
-            if generate_compatible_classes:
-                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
-            raise TypeError(exception_message)

     def _validate_assistant(self, assistant_model):
         if assistant_model is None:
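
For context, the user-facing effect of this hunk: the error no longer queries the auto-mappings for a replacement class, it hints at compatible class-name suffixes instead. A minimal sketch of triggering the check (the checkpoint name is just an example):

```python
from transformers import AutoModel

# A bare encoder has no language-model head, so `can_generate()` is False and
# `_validate_model_class` raises as soon as `.generate()` is called.
model = AutoModel.from_pretrained("bert-base-uncased")

try:
    model.generate()
except TypeError as err:
    print(err)
    # Roughly: "The current model class (BertModel) is not compatible with
    # `.generate()`, as it doesn't have a language model head. Classes that
    # support generation often end in one of these names: ['ForCausalLM', ...]"
```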
36 changes: 28 additions & 8 deletions src/transformers/modeling_utils.py
@@ -212,7 +212,7 @@ def _skip_init(*args, **kwargs):
     setattr(torch.nn.init, name, init_func)


-def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
     except StopIteration:
@@ -227,7 +227,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
         return first_tuple[1].device


-def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     """
     Returns the first parameter dtype (can be non-floating) or asserts if none were found.
     """
@@ -245,7 +245,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
         return first_tuple[1].dtype


-def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     """
     Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
     """
@@ -1309,6 +1309,7 @@ def floating_point_ops(
         return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


+# TODO (joao): remove `GenerationMixin` inheritance in v4.50
 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
     r"""
     Base class for all models.
@@ -1638,11 +1639,30 @@ def can_generate(cls) -> bool:
         Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
-        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
-        # Alternativelly, the model can also have a custom `generate` function.
-        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
-            return False
-        return True
+        # Directly inherits `GenerationMixin` -> can generate
+        if "GenerationMixin" in str(cls.__bases__):
+            return True
+        # Model class overwrites `generate` (e.g. time series models) -> can generate
+        if str(cls.__name__) in str(cls.generate):
+            return True
+        # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
+        # was how we detected whether a model could generate.
+        if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
+            logger.warning_once(
+                f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
+                "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
+                "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
+                "to call `generate` and other related functions."
+                "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
+                "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes"
+                "\n - If you are the owner of the model architecture code, please modify your model class such that "
+                "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)."
+                "\n - If you are not the owner of the model architecture class, please contact the model code owner "
+                "to update it."
+            )
+            return True
+        # Otherwise, can't generate
+        return False

     @classmethod
     def _check_and_enable_flash_attn_2(
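
The warning above spells out the migration path for custom architectures. A minimal sketch of the recommended pattern (class and config names here are hypothetical):

```python
from transformers import PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin


class MyConfig(PretrainedConfig):
    model_type = "my-model"  # hypothetical


# Inherit `GenerationMixin` explicitly, *after* `PreTrainedModel`, so the model
# keeps `generate()` once `PreTrainedModel` drops the inheritance in v4.50.
class MyModelForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = MyConfig


# The new fast path in `can_generate` sees `GenerationMixin` in `__bases__`.
assert MyModelForCausalLM.can_generate()
```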
3 changes: 2 additions & 1 deletion src/transformers/models/albert/modeling_albert.py
@@ -24,6 +24,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -983,7 +984,7 @@ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
     "Albert Model with a `language modeling` head on top.",
     ALBERT_START_DOCSTRING,
 )
-class AlbertForMaskedLM(AlbertPreTrainedModel):
+class AlbertForMaskedLM(AlbertPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

     def __init__(self, config):
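
The remaining model files in this commit repeat the two-line pattern above: import `GenerationMixin` and add it to the bases of each generation-capable head class. A quick sanity check of the effect, assuming a source install of this branch:

```python
from transformers import AlbertForMaskedLM, AlbertModel

# The head class now lists `GenerationMixin` directly in `__bases__`.
print(AlbertForMaskedLM.can_generate())  # True, via the direct-inheritance check
print(AlbertModel.can_generate())        # False: no LM head, no mixin
```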
35 changes: 35 additions & 0 deletions src/transformers/models/auto/auto_factory.py
@@ -29,12 +29,17 @@
     extract_commit_hash,
     find_adapter_config_file,
     is_peft_available,
+    is_torch_available,
     logging,
     requires_backends,
 )
 from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings


+if is_torch_available():
+    from ...generation import GenerationMixin
+
+
 logger = logging.get_logger(__name__)

@@ -428,6 +433,7 @@ def from_config(cls, config, **kwargs):
             model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
             cls.register(config.__class__, model_class, exist_ok=True)
             _ = kwargs.pop("code_revision", None)
+            model_class = add_generation_mixin_to_remote_model(model_class)
             return model_class._from_config(config, **kwargs)
         elif type(config) in cls._model_mapping.keys():
             model_class = _get_model_class(config, cls._model_mapping)
@@ -549,6 +555,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             )
             _ = hub_kwargs.pop("code_revision", None)
             cls.register(config.__class__, model_class, exist_ok=True)
+            model_class = add_generation_mixin_to_remote_model(model_class)
             return model_class.from_pretrained(
                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
             )
@@ -698,6 +705,34 @@ def getattribute_from_module(module, attr):
     raise ValueError(f"Could not find {attr} in {transformers_module}!")


+def add_generation_mixin_to_remote_model(model_class):
+    """
+    Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.
+
+    This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
+    `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
+    from the Hub may not have the `generate` method after we remove the inheritance.
+    """
+    # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
+    if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
+        return model_class
+
+    # 2. If it already **directly** inherits from GenerationMixin, do nothing
+    if "GenerationMixin" in str(model_class.__bases__):
+        return model_class
+
+    # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate`
+    #    and/or `prepare_inputs_for_generation` method.
+    has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate"))
+    has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation"))
+    if has_custom_generate or has_custom_prepare_inputs:
+        model_class_with_generation_mixin = type(
+            model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
+        )
+        return model_class_with_generation_mixin
+    return model_class


 class _LazyAutoMapping(OrderedDict):
     """
     " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
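
The `type(name, bases, namespace)` call above grafts the mixin onto the remote class without editing its source. A self-contained sketch of the same trick on stand-in classes (all names here are illustrative, not part of the library):

```python
class DemoGenerationMixin:
    """Stand-in for `GenerationMixin`."""

    def generate(self):
        return "generated!"


class DemoPreTrainedModel:
    """Stand-in for `PreTrainedModel`."""


class RemoteModel(DemoPreTrainedModel):
    """Stand-in for a Hub class carrying a pre-v4.45 generation signal."""

    def prepare_inputs_for_generation(self):
        return {}


# Rebuild the class with the mixin appended to its bases; the original name
# and namespace are preserved, and `generate` is now reachable via the MRO.
PatchedModel = type(
    RemoteModel.__name__, (RemoteModel, DemoGenerationMixin), {**RemoteModel.__dict__}
)

print(PatchedModel.__name__)      # RemoteModel
print(PatchedModel().generate())  # generated!
```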
3 changes: 2 additions & 1 deletion src/transformers/models/bark/modeling_bark.py
@@ -22,6 +22,7 @@
 from torch import nn
 from torch.nn import functional as F

+from ...generation import GenerationMixin
 from ...generation.logits_process import (
     AlternatingCodebooksLogitsProcessor,
     BarkEosPrioritizerLogitsProcessor,
@@ -546,7 +547,7 @@ def device(self) -> torch.device:


 # GPT2-like autoregressive model
-class BarkCausalModel(BarkPreTrainedModel):
+class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
     config_class = BarkSubModelConfig

     def __init__(self, config):
5 changes: 3 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
@@ -25,6 +25,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
     _prepare_4d_attention_mask_for_sdpa,
@@ -1557,7 +1558,7 @@ def forward(
 @add_start_docstrings(
     "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
 )
-class BartForConditionalGeneration(BartPreTrainedModel):
+class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
@@ -2010,7 +2011,7 @@ def forward(self, *args, **kwargs):
     """,
     BART_START_DOCSTRING,
 )
-class BartForCausalLM(BartPreTrainedModel):
+class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
3 changes: 2 additions & 1 deletion src/transformers/models/bert/modeling_bert.py
@@ -28,6 +28,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import (
     _prepare_4d_attention_mask_for_sdpa,
     _prepare_4d_causal_attention_mask_for_sdpa,
@@ -1280,7 +1281,7 @@ def forward(
 @add_start_docstrings(
     """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
 )
-class BertLMHeadModel(BertPreTrainedModel):
+class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

     def __init__(self, config):
src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -23,6 +23,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
@@ -863,7 +864,7 @@ def _tie_weights(self):
     """BertGeneration Model with a `language modeling` head on top for CLM fine-tuning.""",
     BERT_GENERATION_START_DOCSTRING,
 )
-class BertGenerationDecoder(BertGenerationPreTrainedModel):
+class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     def __init__(self, config):
3 changes: 2 additions & 1 deletion src/transformers/models/big_bird/modeling_big_bird.py
@@ -26,6 +26,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -2495,7 +2496,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_
 @add_start_docstrings(
     """BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING
 )
-class BigBirdForCausalLM(BigBirdPreTrainedModel):
+class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -24,6 +24,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -2436,7 +2437,7 @@ def forward(
     BIGBIRD_PEGASUS_START_DOCSTRING,
 )
 # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
-class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
+class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
@@ -2882,7 +2883,7 @@ def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)


-class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
+class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
3 changes: 2 additions & 1 deletion src/transformers/models/biogpt/modeling_biogpt.py
@@ -23,6 +23,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -719,7 +720,7 @@ def forward(
 @add_start_docstrings(
     """BioGPT Model with a `language modeling` head on top for CLM fine-tuning.""", BIOGPT_START_DOCSTRING
 )
-class BioGptForCausalLM(BioGptPreTrainedModel):
+class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["output_projection.weight"]

     def __init__(self, config):
5 changes: 3 additions & 2 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -26,6 +26,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -1196,7 +1197,7 @@ def forward(
 @add_start_docstrings(
     "The Blenderbot Model with a language modeling head. Can be used for summarization.", BLENDERBOT_START_DOCSTRING
 )
-class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
+class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
     _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
@@ -1397,7 +1398,7 @@ def forward(self, *args, **kwargs):


 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
-class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
+class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -24,6 +24,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -1163,7 +1164,7 @@ def forward(
     "The BlenderbotSmall Model with a language modeling head. Can be used for summarization.",
     BLENDERBOT_SMALL_START_DOCSTRING,
 )
-class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
+class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
     _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
@@ -1349,7 +1350,7 @@ def forward(self, *args, **kwargs):


 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
-class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
+class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):