diff --git a/src/transformers/models/musicgen/convert_musicgen_transformers.py b/src/transformers/models/musicgen/convert_musicgen_transformers.py index f1eb9e40704dfe..a072ec321b73c8 100644 --- a/src/transformers/models/musicgen/convert_musicgen_transformers.py +++ b/src/transformers/models/musicgen/convert_musicgen_transformers.py @@ -88,24 +88,24 @@ def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict, def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig: - if checkpoint == "small" or checkpoint == "facebook/musicgen-stereo-small": + if checkpoint.endswith("small"): # default config values hidden_size = 1024 num_hidden_layers = 24 num_attention_heads = 16 - elif checkpoint == "medium" or checkpoint == "facebook/musicgen-stereo-medium": + elif checkpoint.endswith("medium"): hidden_size = 1536 num_hidden_layers = 48 num_attention_heads = 24 - elif checkpoint == "large" or checkpoint == "facebook/musicgen-stereo-large": + elif checkpoint.endswith("large"): hidden_size = 2048 num_hidden_layers = 48 num_attention_heads = 32 else: raise ValueError( "Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, " - "or `['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` " - f"for the stereo checkpoints, got {checkpoint}." + "`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` " + f"for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix, got {checkpoint}." ) if "stereo" in checkpoint: @@ -208,9 +208,9 @@ def convert_musicgen_checkpoint( default="small", type=str, help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: " - "`['small', 'medium', 'large']` for the mono checkpoints, or " + "`['small', 'medium', 'large']` for the mono checkpoints, " "`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` " - "for the stereo checkpoints.", + "for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix.", ) parser.add_argument( "--pytorch_dump_folder",