Pass attn_implementation when using AutoXXX.from_config (#30507)
* Pass attn_implementation when using AutoXXX.from_config

* Fix
amyeroberts authored Apr 29, 2024
1 parent 80126f9 commit e8acb70
Showing 10 changed files with 41 additions and 19 deletions.
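
For context, the pattern the commit applies across all ten files is sketched below. This is a minimal illustration, not code from the diff: the checkpoint id and backend name are placeholders, and it assumes a transformers version in which from_config accepts an attn_implementation keyword.

    from transformers import AutoConfig, AutoModelForCausalLM

    # A composite config; the checkpoint id is illustrative.
    config = AutoConfig.from_pretrained("Salesforce/blip2-opt-2.7b")
    config._attn_implementation = "sdpa"  # request a specific attention backend

    # Before this commit: the sub-model ignored the top-level setting
    # and fell back to its own default attention implementation.
    language_model = AutoModelForCausalLM.from_config(config.text_config)

    # After this commit: composite models forward the top-level setting
    # to each sub-model they build via AutoXXX.from_config.
    language_model = AutoModelForCausalLM.from_config(
        config.text_config, attn_implementation=config._attn_implementation
    )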
16 changes: 12 additions & 4 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1194,9 +1194,13 @@ def __init__(self, config: Blip2Config):
 
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
         if config.use_decoder_only_language_model:
-            language_model = AutoModelForCausalLM.from_config(config.text_config)
+            language_model = AutoModelForCausalLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         else:
-            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+            language_model = AutoModelForSeq2SeqLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
 
         # Update _tied_weights_keys using the base model used.
         if language_model._tied_weights_keys is not None:
@@ -1549,9 +1553,13 @@ def __init__(self, config: Blip2Config):
 
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
         if config.use_decoder_only_language_model:
-            language_model = AutoModelForCausalLM.from_config(config.text_config)
+            language_model = AutoModelForCausalLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         else:
-            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+            language_model = AutoModelForSeq2SeqLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
 
         # Update _tied_weights_keys using the base model used.
         if language_model._tied_weights_keys is not None:
4 changes: 3 additions & 1 deletion src/transformers/models/depth_anything/modeling_depth_anything.py
@@ -367,7 +367,9 @@ class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.backbone = AutoBackbone.from_config(config.backbone_config)
+        self.backbone = AutoBackbone.from_config(
+            config.backbone_config, attn_implementation=config._attn_implementation
+        )
         self.neck = DepthAnythingNeck(config)
         self.head = DepthAnythingDepthEstimationHead(config)
 
4 changes: 2 additions & 2 deletions src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -209,12 +209,12 @@ def __init__(
         if encoder is None:
             from ..auto.modeling_auto import AutoModel
 
-            encoder = AutoModel.from_config(config.encoder)
+            encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
 
         if decoder is None:
             from ..auto.modeling_auto import AutoModelForCausalLM
 
-            decoder = AutoModelForCausalLM.from_config(config.decoder)
+            decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
 
         self.encoder = encoder
         self.decoder = decoder
4 changes: 3 additions & 1 deletion src/transformers/models/fuyu/modeling_fuyu.py
@@ -149,7 +149,9 @@ def __init__(self, config: FuyuConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+        self.language_model = AutoModelForCausalLM.from_config(
+            config.text_config, attn_implementation=config._attn_implementation
+        )
 
         self.vision_embed_tokens = nn.Linear(
             config.patch_size * config.patch_size * config.num_channels, config.hidden_size
2 changes: 1 addition & 1 deletion src/transformers/models/idefics2/modeling_idefics2.py
@@ -1476,7 +1476,7 @@ def __init__(self, config: Idefics2Config):
 
         self.vision_model = Idefics2VisionTransformer(config.vision_config)
         self.connector = Idefics2Connector(config)
-        self.text_model = AutoModel.from_config(config.text_config)
+        self.text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation)
 
         self.image_seq_len = config.perceiver_config.resampler_n_latents
         self.image_token_id = self.config.image_token_id
8 changes: 6 additions & 2 deletions src/transformers/models/instructblip/modeling_instructblip.py
@@ -1251,9 +1251,13 @@ def __init__(self, config: InstructBlipConfig):
         self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
 
         if config.use_decoder_only_language_model:
-            language_model = AutoModelForCausalLM.from_config(config.text_config)
+            language_model = AutoModelForCausalLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
         else:
-            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+            language_model = AutoModelForSeq2SeqLM.from_config(
+                config.text_config, attn_implementation=config._attn_implementation
+            )
 
         if language_model._no_split_modules is not None:
             self._no_split_modules.extend(language_model._no_split_modules)
8 changes: 6 additions & 2 deletions src/transformers/models/rag/modeling_rag.py
@@ -506,12 +506,16 @@ def __init__(
         if question_encoder is None:
             from ..auto.modeling_auto import AutoModel
 
-            question_encoder = AutoModel.from_config(config.question_encoder)
+            question_encoder = AutoModel.from_config(
+                config.question_encoder, attn_implementation=config._attn_implementation
+            )
 
         if generator is None:
             from ..auto.modeling_auto import AutoModelForSeq2SeqLM
 
-            generator = AutoModelForSeq2SeqLM.from_config(config.generator)
+            generator = AutoModelForSeq2SeqLM.from_config(
+                config.generator, attn_implementation=config._attn_implementation
+            )
 
         self.retriever = retriever
         if self.retriever is not None:
4 changes: 2 additions & 2 deletions src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
@@ -212,10 +212,10 @@ def __init__(
         super().__init__(config)
 
         if encoder is None:
-            encoder = AutoModel.from_config(config.encoder)
+            encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
 
         if decoder is None:
-            decoder = AutoModelForCausalLM.from_config(config.decoder)
+            decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
 
         self.encoder = encoder
         self.decoder = decoder
4 changes: 2 additions & 2 deletions src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -190,10 +190,10 @@ def __init__(
         super().__init__(config)
 
         if encoder is None:
-            encoder = AutoModel.from_config(config.encoder)
+            encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
 
         if decoder is None:
-            decoder = AutoModelForCausalLM.from_config(config.decoder)
+            decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
 
         self.encoder = encoder
         self.decoder = decoder
6 changes: 4 additions & 2 deletions src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py
@@ -185,10 +185,12 @@ def __init__(
             if isinstance(config.vision_config, CLIPVisionConfig):
                 vision_model = CLIPVisionModel(config.vision_config)
             else:
-                vision_model = AutoModel.from_config(config.vision_config)
+                vision_model = AutoModel.from_config(
+                    config.vision_config, attn_implementation=config._attn_implementation
+                )
 
         if text_model is None:
-            text_model = AutoModel.from_config(config.text_config)
+            text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation)
 
         self.vision_model = vision_model
         self.text_model = text_model
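
As a quick sanity check of the new behavior, one could load a composite model and inspect a sub-model's config. This is a hedged sketch: the checkpoint id is illustrative and _attn_implementation is private API, so the exact attribute layout may vary by version.

    from transformers import Blip2ForConditionalGeneration

    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", attn_implementation="eager"
    )
    # With this commit, the inner language model is built with the same
    # attention backend as the top-level model instead of its own default.
    print(model.language_model.config._attn_implementation)  # expected: "eager"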
