fix cpu and xpu issue (#2116)
Signed-off-by: Wang, Yi A <[email protected]>
sywangyi authored Jun 25, 2024
Parent: 9e2fdf5, commit: e563983
Showing 7 changed files with 16 additions and 7 deletions.
server/text_generation_server/models/flash_causal_lm.py (4 additions, 1 deletion)
@@ -768,7 +768,10 @@ def init_kv_cache(
         empty_cache()

         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = BLOCK_SIZE // element_size
+        if SYSTEM == "ipex" and device.type == "xpu":
+            x = 1
+        else:
+            x = BLOCK_SIZE // element_size

         if SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
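
The hunk above changes how the KV-cache packing factor x is chosen: the ipex XPU path apparently expects an unpacked layout (x = 1), while every other backend keeps x = BLOCK_SIZE // element_size. The sketch below restates that arithmetic as a standalone function; the kv_cache_packing name and the BLOCK_SIZE value of 16 are illustrative assumptions, not the repository's actual API or constant value.

import torch

BLOCK_SIZE = 16  # assumed block size for illustration; the real constant is defined in the TGI source

def kv_cache_packing(system: str, device_type: str, dtype: torch.dtype) -> int:
    # Hypothetical helper restating the branch from the hunk above:
    # ipex on XPU keeps the KV cache unpacked; other backends derive the
    # packing factor from the dtype's element size.
    element_size = torch.tensor([], dtype=dtype).element_size()
    if system == "ipex" and device_type == "xpu":
        return 1
    return BLOCK_SIZE // element_size

# float16 elements are 2 bytes, so a CUDA device would get x = 16 // 2 = 8,
# while an XPU device under ipex gets x = 1.
print(kv_cache_packing("cuda", "cuda", torch.float16))  # 8
print(kv_cache_packing("ipex", "xpu", torch.float16))   # 1
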
server/text_generation_server/models/flash_gpt2.py (2 additions, 1 deletion)
@@ -37,9 +37,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-                dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")

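The same two-line dtype fix appears in each of the six flash_* model constructors touched by this commit (the hunk above and the five below): the XPU branch now gets an explicit float16 default, matching the CUDA path, while the CPU fallback defaults to bfloat16 instead of float16. A minimal sketch of the resulting selection logic, assuming a hypothetical ipex_device_and_dtype helper (the real code inlines this logic in each __init__):

from typing import Optional, Tuple

import torch

def ipex_device_and_dtype(
    rank: int, dtype: Optional[torch.dtype]
) -> Tuple[torch.device, torch.dtype]:
    # Hypothetical helper restating the defaults this commit applies per model.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        # XPU gets the same float16 default as the CUDA path.
        device = torch.device(f"xpu:{rank}")
        dtype = torch.float16 if dtype is None else dtype
    else:
        # The CPU fallback now defaults to bfloat16, which is generally the
        # better-supported reduced precision for ipex on CPU.
        device = torch.device("cpu")
        dtype = torch.bfloat16 if dtype is None else dtype
    return device, dtype

For example, calling ipex_device_and_dtype(rank, None) on the SYSTEM == "ipex" branch would yield an xpu device with float16 when an XPU is present, and a cpu device with bfloat16 otherwise.
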
server/text_generation_server/models/flash_llama.py (2 additions, 1 deletion)
@@ -37,9 +37,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-                dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

server/text_generation_server/models/flash_mistral.py (2 additions, 1 deletion)
@@ -41,9 +41,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-                dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

server/text_generation_server/models/flash_neox.py (2 additions, 1 deletion)
@@ -36,9 +36,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-                dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

server/text_generation_server/models/flash_rw.py (2 additions, 1 deletion)
@@ -37,9 +37,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-                dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")

server/text_generation_server/models/flash_santacoder.py (2 additions, 1 deletion)
@@ -40,9 +40,10 @@ def __init__(
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-                dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

