Skip to content

Commit

Permalink
Add support for custom checkpoints in MusicGen (#30011)
Browse files Browse the repository at this point in the history
* feat: support custom checkpoint

* update: revert changes and add TODO

* update: docs and exception handling

* fix: ah, extra space
  • Loading branch information
jla524 authored May 15, 2024
1 parent 1360801 commit 99e1612
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,24 @@ def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict,


def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
if checkpoint == "small" or checkpoint == "facebook/musicgen-stereo-small":
if checkpoint.endswith("small"):
# default config values
hidden_size = 1024
num_hidden_layers = 24
num_attention_heads = 16
elif checkpoint == "medium" or checkpoint == "facebook/musicgen-stereo-medium":
elif checkpoint.endswith("medium"):
hidden_size = 1536
num_hidden_layers = 48
num_attention_heads = 24
elif checkpoint == "large" or checkpoint == "facebook/musicgen-stereo-large":
elif checkpoint.endswith("large"):
hidden_size = 2048
num_hidden_layers = 48
num_attention_heads = 32
else:
raise ValueError(
"Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, "
"or `['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
f"for the stereo checkpoints, got {checkpoint}."
"`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
f"for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix, got {checkpoint}."
)

if "stereo" in checkpoint:
Expand Down Expand Up @@ -208,9 +208,9 @@ def convert_musicgen_checkpoint(
default="small",
type=str,
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: "
"`['small', 'medium', 'large']` for the mono checkpoints, or "
"`['small', 'medium', 'large']` for the mono checkpoints, "
"`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
"for the stereo checkpoints.",
"for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix.",
)
parser.add_argument(
"--pytorch_dump_folder",
Expand Down

0 comments on commit 99e1612

Please sign in to comment.