diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index aa43107f161..7f8268a93be 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -768,7 +768,10 @@ def init_kv_cache(
         empty_cache()

         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = BLOCK_SIZE // element_size
+        if SYSTEM == "ipex" and device.type == "xpu":
+            x = 1
+        else:
+            x = BLOCK_SIZE // element_size

         if SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py
index 75c7203ac02..2d0f9fcc513 100644
--- a/server/text_generation_server/models/flash_gpt2.py
+++ b/server/text_generation_server/models/flash_gpt2.py
@@ -37,9 +37,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")

diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 76c522e392b..9366706fd00 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -37,9 +37,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 78a09cf5780..16778ada6a6 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -41,9 +41,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 9c82bf523e8..87ae570c98f 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -36,9 +36,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index e8087f230a0..6ed1f6f740b 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -37,9 +37,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")

diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 83a6b92c97e..ab1e451669e 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -40,9 +40,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError(
                 "FlashSantacoderSharded is only available on GPU"
             )
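
Net effect of the diff, for reviewers: (1) under IPEX on XPU the KV cache drops the byte-packing factor, using x = 1 instead of BLOCK_SIZE // element_size, and (2) the default dtype under IPEX becomes device-specific, float16 on XPU versus bfloat16 on CPU, where previously float16 was the default for both. The following is a minimal, self-contained sketch of the resulting selection logic, not TGI code: SYSTEM, BLOCK_SIZE, and the branch structure come from the hunks above, while the CUDA branch, the BLOCK_SIZE = 16 value, and the helper names select_device_and_dtype / kv_packing_factor are assumptions for illustration.

    import torch

    SYSTEM = "ipex"   # assumption: stands in for TGI's SYSTEM constant on an IPEX build
    BLOCK_SIZE = 16   # assumption: paged-attention block size used by the KV-cache math

    def select_device_and_dtype(rank, dtype=None):
        # Hypothetical helper mirroring the per-model __init__ branches in the diff.
        if torch.cuda.is_available():
            # CUDA branch assumed from the surrounding code (not shown in the hunks).
            return torch.device(f"cuda:{rank}"), torch.float16 if dtype is None else dtype
        if SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                # XPU keeps the old float16 default.
                return torch.device(f"xpu:{rank}"), torch.float16 if dtype is None else dtype
            # CPU now defaults to bfloat16 (the behavior change in this diff).
            return torch.device("cpu"), torch.bfloat16 if dtype is None else dtype
        raise NotImplementedError("only the CUDA and IPEX paths are sketched here")

    def kv_packing_factor(device, dtype):
        # Mirrors the init_kv_cache change: XPU stores the last KV dimension
        # unpacked (x == 1); other backends pack BLOCK_SIZE bytes of elements.
        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "ipex" and device.type == "xpu":
            return 1
        return BLOCK_SIZE // element_size

    device, dtype = select_device_and_dtype(rank=0)
    # On a CPU-only IPEX host this prints: cpu torch.bfloat16 8
    print(device, dtype, kv_packing_factor(device, dtype))

A plausible rationale for the dtype split, not stated in the diff itself: recent Xeons accelerate bfloat16 natively (AVX512-BF16/AMX) while float16 on CPU is frequently emulated, so bfloat16 is the safer CPU default; XPU hardware handles float16 well, so it retains the previous default.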