latest transformers changes
Cyrilvallez committed Dec 19, 2024
1 parent f4c60ca commit 2e2631e
Showing 2 changed files with 10 additions and 14 deletions.
14 changes: 8 additions & 6 deletions server/text_generation_server/models/__init__.py
@@ -12,7 +12,7 @@
 
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto, modeling_task
+from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
@@ -380,12 +380,14 @@ def get_model(
         logger.info(
             "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
         )
-        try:
-            transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-        except KeyError:
-            transformers_model_class = modeling_task.AutoForCausalLM
+        transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
+        # Ugly check but works in the meantime
+        model_path = os.path.join(os.path.dirname(transformers.__file__), "models", model_type, f"modeling_{model_type}.py")
+        with open(model_path) as file:
+            has_fa2_class = f"FlashAttention2(" in file.read()
 
-        if transformers_model_class._supports_flash_attn_2:
+        if transformers_model_class._supports_flash_attn_2 and not has_fa2_class:
             logger.info(
                 f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
                 "batch and sequence length). All TGI's batching/caching optimizations are enabled."
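Side note: the new gate above boils down to "transformers advertises FA2 support for this architecture, and its modeling file no longer defines a model-specific FlashAttention2 class". A minimal standalone sketch of that probe follows (the helper name has_dedicated_fa2_class is hypothetical, not part of this commit):

import os

import transformers


def has_dedicated_fa2_class(model_type: str) -> bool:
    # Locate the installed copy of transformers' modeling_<model_type>.py and
    # check whether it still defines a class like `XxxFlashAttention2(...)`.
    model_path = os.path.join(
        os.path.dirname(transformers.__file__),
        "models",
        model_type,
        f"modeling_{model_type}.py",
    )
    with open(model_path) as file:
        return "FlashAttention2(" in file.read()


print(has_dedicated_fa2_class("llama"))  # e.g. probe the llama implementation

The diff enables TGI's ragged-tensor path only when this probe comes back False, i.e. when the architecture has been moved to the refactored attention interface.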
10 changes: 2 additions & 8 deletions (second changed file; path not captured)
@@ -52,12 +52,6 @@ def tgi_flash_attention_forward(
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
 
-    input_dtype = query_states.dtype
-    if input_dtype == torch.float32:
-        query_states = query_states.to(target_dtype)
-        key_states = key_states.to(target_dtype)
-        value_states = value_states.to(target_dtype)
-
     # Take care of updating the cache in-place
     kv_cache.store(
         key=key_states,
@@ -66,7 +60,6 @@
         kv_scales=kv_scales
     )
 
-
     _, num_heads, head_dim = query_states.shape
     softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
     sliding_window = -1 if sliding_window is None else sliding_window
@@ -155,7 +148,8 @@ def __init__(
             device_map=("auto" if device_count > 1 else None),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
-            attn_implementation="tgi"
+            attn_implementation="tgi",
+            tp_plan="auto" if world_size > 1 else None,
         )
 
         if device_count == 1 and quantize != "bitsandbytes":
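Side note: a minimal sketch of how the two keyword arguments touched here look from the transformers side, assuming a release recent enough to accept tp_plan. The model id, the "sdpa" choice, and the world-size handling below are illustrative assumptions, not this commit's code; TGI passes attn_implementation="tgi", presumably wired up to the tgi_flash_attention_forward shown earlier in this file.

import os

from transformers import AutoModelForCausalLM

# attn_implementation selects the attention backend by name; tp_plan="auto"
# asks transformers to shard the model across ranks when it is launched with
# several processes (e.g. under torchrun).
world_size = int(os.environ.get("WORLD_SIZE", "1"))

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",   # placeholder model id
    attn_implementation="sdpa",  # TGI uses its own registered "tgi" backend instead
    tp_plan="auto" if world_size > 1 else None,
)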
