diff --git a/hqq/models/hf/base.py b/hqq/models/hf/base.py index f8231ab..68fab2f 100755 --- a/hqq/models/hf/base.py +++ b/hqq/models/hf/base.py @@ -13,8 +13,15 @@ def cache_model(cls, model, save_dir): # Create empty model from config @classmethod - def create_model(cls, save_dir): - config = transformers.AutoConfig.from_pretrained(cls.get_config_file(save_dir)) + def create_model(cls, save_dir, kwargs): + config_kwargs = {} + for key in ["attn_implementation"]: + if key in kwargs: + config_kwargs[key] = kwargs[key] + + config = transformers.AutoConfig.from_pretrained( + cls.get_config_file(save_dir), **config_kwargs + ) auto_class = transformers.AutoModel