diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 065b22f2ab1..e9e8680eacc 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -731,13 +731,11 @@ def __init__( } - if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gemma"]: - + if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon"]: if model.config.model_type not in ["falcon"]: self.kwargs["attn_softmax_bf16"] = True - if model.config.model_type not in ["gemma"]: - self.kwargs["trim_logits"] = True + self.kwargs["trim_logits"] = True if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true": self.kwargs["use_flash_attention"] = True