Skip to content

Commit

Permalink
Fix normalized config key for models architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 22, 2023
1 parent 6d1ae0e commit d6f8e10
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class NormalizedConfigManager:
# "big_bird": NormalizedTextConfig,
# "bigbird_pegasus": BartLikeNormalizedTextConfig,
"blenderbot": BartLikeNormalizedTextConfig,
"blenderbot_small": BartLikeNormalizedTextConfig,
"blenderbot-small": BartLikeNormalizedTextConfig,
"bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
"falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"),
"camembert": NormalizedTextConfig,
Expand All @@ -224,16 +224,16 @@ class NormalizedConfigManager:
"encoder-decoder": NormalizedEncoderDecoderConfig,
"gpt2": GPT2LikeNormalizedTextConfig,
"gpt-bigcode": GPT2LikeNormalizedTextConfig,
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt_neox": NormalizedTextConfig,
"gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt-neox": NormalizedTextConfig,
"llama": NormalizedTextConfig,
"gptj": GPT2LikeNormalizedTextConfig,
"imagegpt": GPT2LikeNormalizedTextConfig,
"longt5": T5LikeNormalizedTextConfig,
"marian": BartLikeNormalizedTextConfig,
"mbart": BartLikeNormalizedTextConfig,
"mt5": T5LikeNormalizedTextConfig,
"m2m_100": BartLikeNormalizedTextConfig,
"m2m-100": BartLikeNormalizedTextConfig,
"nystromformer": NormalizedTextConfig,
"opt": NormalizedTextConfig,
"pegasus": BartLikeNormalizedTextConfig,
Expand All @@ -242,7 +242,7 @@ class NormalizedConfigManager:
"regnet": NormalizedVisionConfig,
"resnet": NormalizedVisionConfig,
"roberta": NormalizedTextConfig,
"speech_to_text": SpeechToTextLikeNormalizedTextConfig,
"speech-to-text": SpeechToTextLikeNormalizedTextConfig,
"splinter": NormalizedTextConfig,
"t5": T5LikeNormalizedTextConfig,
"trocr": TrOCRLikeNormalizedTextConfig,
Expand All @@ -252,7 +252,7 @@ class NormalizedConfigManager:
"xlm-roberta": NormalizedTextConfig,
"yolos": NormalizedVisionConfig,
"mpt": MPTNormalizedTextConfig,
"gpt_bigcode": GPTBigCodeNormalizedTextConfig,
"gpt-bigcode": GPTBigCodeNormalizedTextConfig,
}

@classmethod
Expand All @@ -266,5 +266,6 @@ def check_supported_model(cls, model_type: str):

@classmethod
def get_normalized_config_class(cls, model_type: str) -> Type:
model_type = model_type.replace("_", "-")
cls.check_supported_model(model_type)
return cls._conf[model_type]

0 comments on commit d6f8e10

Please sign in to comment.