Autoset attn_implementation in config
qubvel committed Jun 24, 2024
1 parent a12367b commit 9d09374
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions src/transformers/models/siglip/modeling_siglip.py
@@ -825,33 +825,6 @@ def _init_weights(self, module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

@classmethod
def _autoset_attn_implementation(
cls,
config,
use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
**kwargs,
):
"""
Overrides the method in `PreTrainedModel` to update the vision config with the correct attention implementation
"""
config = super()._autoset_attn_implementation(
config=config,
use_flash_attention_2=use_flash_attention_2,
torch_dtype=torch_dtype,
device_map=device_map,
check_device_map=check_device_map,
**kwargs,
)
if hasattr(config, "vision_config"):
config.vision_config._attn_implementation = config._attn_implementation
if hasattr(config, "text_config"):
config.text_config._attn_implementation = config._attn_implementation
return config


SIGLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -1356,6 +1329,31 @@ def __init__(self, config: SiglipConfig):
# Initialize weights and apply final processing
self.post_init()

@classmethod
def _autoset_attn_implementation(
cls,
config,
use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
**kwargs,
):
"""
Overrides the method in `PreTrainedModel` to update the vision config with the correct attention implementation
"""
config = super()._autoset_attn_implementation(
config=config,
use_flash_attention_2=use_flash_attention_2,
torch_dtype=torch_dtype,
device_map=device_map,
check_device_map=check_device_map,
**kwargs,
)
config.vision_config._attn_implementation = config._attn_implementation
config.text_config._attn_implementation = config._attn_implementation
return config

@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
def get_text_features(
self,
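The override now lives on `SiglipModel`, whose `SiglipConfig` always carries both `vision_config` and `text_config`, so the `hasattr` guards from the old base-class version are dropped. The practical effect is that whichever attention backend is resolved for the top-level config is mirrored into both sub-configs at load time. A minimal sketch of that effect (not part of the commit; it assumes the public `google/siglip-base-patch16-224` checkpoint and the standard `from_pretrained` keyword arguments):

    # Illustrative usage: after this change, the attention implementation
    # requested at load time is propagated to both sub-configs.
    import torch
    from transformers import SiglipModel

    model = SiglipModel.from_pretrained(
        "google/siglip-base-patch16-224",   # example checkpoint, for illustration only
        attn_implementation="sdpa",         # or "eager" / "flash_attention_2"
        torch_dtype=torch.float16,
    )

    print(model.config._attn_implementation)                # "sdpa"
    print(model.config.vision_config._attn_implementation)  # "sdpa" as well
    print(model.config.text_config._attn_implementation)    # "sdpa" as well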
