Removing IPEX_AVAIL. (#2115)
* Removing IPEX_AVAIL.

Chose to unify CPU and XPU under `ipex`. Most of the code is identical except for a few spots.

The largest share of those spots is in the kv-cache layout and the flash_xxx.py files. Since those files should soon be factored out and removed, we should not need the special cases for long (the unified detection is sketched below, after these notes).

* Forgot a few places.

* Unrelated change.

* Fixing HF_TOKEN.

* HF_TOKEN
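
For context, the whole change hinges on `SYSTEM` taking the single value "ipex" for both Intel XPU and CPU. Below is a minimal sketch of that detection, not the exact `import_utils.py` implementation; the probe order and the helper name are assumptions.

import torch


def is_ipex_available() -> bool:
    # Assumed helper: "ipex" is chosen whenever intel_extension_for_pytorch imports,
    # whether an XPU device or only a CPU is present.
    try:
        import intel_extension_for_pytorch  # noqa: F401
    except ImportError:
        return False
    return True


if torch.cuda.is_available():
    # ROCm builds of PyTorch also expose torch.cuda but set torch.version.hip.
    SYSTEM = "rocm" if torch.version.hip is not None else "cuda"
elif is_ipex_available():
    SYSTEM = "ipex"
else:
    SYSTEM = "cpu"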
Narsil authored Jun 25, 2024
1 parent 3f3b7ff commit 9e2fdf5
Showing 22 changed files with 79 additions and 78 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
@@ -178,6 +178,6 @@ jobs:
export DOCKER_VOLUME=/mnt/cache
export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
export HF_TOKEN=${{ secrets.HF_TOKEN }}
echo $DOCKER_IMAGE
pytest -s -vv integration-tests
2 changes: 1 addition & 1 deletion .github/workflows/client-tests.yaml
@@ -22,5 +22,5 @@ jobs:
- name: Run tests
run: |
pip install pytest pytest-asyncio
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
export HF_TOKEN=${{ secrets.HF_TOKEN }}
make python-client-tests
2 changes: 1 addition & 1 deletion .github/workflows/integration_tests.yaml
@@ -37,5 +37,5 @@ jobs:
export DOCKER_VOLUME=/mnt/cache
export DOCKER_IMAGE=${{ inputs.docker_image }}
export DOCKER_DEVICES=${{ inputs.docker_devices }}
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv integration-tests
2 changes: 1 addition & 1 deletion .github/workflows/load_test.yaml
@@ -28,7 +28,7 @@ jobs:
- name: Start starcoder
run: |
docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
sleep 10
wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -72,7 +72,7 @@ jobs:
- name: Run server tests
run: |
pip install pytest
export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests
- name: Pre-commit checks
run: |
2 changes: 1 addition & 1 deletion clients/python/text_generation/types.py
@@ -455,6 +455,6 @@ class DeployedModel(BaseModel):
# Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
# with model_ prefixes, since this disables guardrails for colliding fields:
# https://github.com/pydantic/pydantic/issues/9177
model_config = ConfigDict(protected_namespaces=())
model_config = ConfigDict(protected_namespaces=())
model_id: str
sha: str
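
The comment above exists because pydantic v2 reserves the model_ prefix; the snippet below simply reproduces the pattern from this file as a self-contained illustration and adds nothing beyond what the diff shows.

from pydantic import BaseModel, ConfigDict


class DeployedModel(BaseModel):
    # Opt out of pydantic v2's protected "model_" namespace so that the
    # "model_id" field does not trigger a collision warning.
    model_config = ConfigDict(protected_namespaces=())

    model_id: str
    sha: str


print(DeployedModel(model_id="bigcode/starcoder", sha="abc123").model_id)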
6 changes: 3 additions & 3 deletions server/text_generation_server/layers/attention/__init__.py
@@ -1,4 +1,4 @@
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM
import os

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
@@ -7,7 +7,7 @@
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif IPEX_AVAIL:
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
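
For callers nothing changes: the backend is still resolved once at import time and the same names are re-exported, so downstream modules keep using an import along these lines (illustrative only).

from text_generation_server.layers.attention import (
    SUPPORTS_WINDOWING,
    attention,
    paged_attention,
    reshape_and_cache,
)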
5 changes: 2 additions & 3 deletions server/text_generation_server/layers/layernorm.py
@@ -3,7 +3,6 @@
from accelerate import init_empty_weights
from text_generation_server.utils.import_utils import (
SYSTEM,
IPEX_AVAIL,
)


@@ -83,7 +82,7 @@ def forward(self, hidden_states, residual=None):

return super().forward(hidden_states), residual

elif IPEX_AVAIL:
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex

class FastLayerNorm(nn.LayerNorm):
Expand Down Expand Up @@ -112,7 +111,7 @@ def load(cls, prefix, weights, eps=1e-6):
return cls(weight, eps)

def forward(self, hidden_states, residual=None):
if IPEX_AVAIL:
if SYSTEM == "ipex":
out = ipex.llm.functional.add_rms_norm(
residual,
hidden_states,
6 changes: 3 additions & 3 deletions server/text_generation_server/layers/rotary.py
@@ -2,14 +2,14 @@
import torch
from torch import nn

from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "cuda":
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops
elif IPEX_AVAIL:
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex


@@ -69,7 +69,7 @@ def forward(

# Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif IPEX_AVAIL:
elif SYSTEM == "ipex":
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True
)
12 changes: 6 additions & 6 deletions server/text_generation_server/layers/tensor_parallel.py
@@ -3,9 +3,9 @@
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM

if IPEX_AVAIL:
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex


@@ -100,7 +100,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
local_out = gather_input.T

torch.mm(input, self.linear.weight.T, out=local_out)
if IPEX_AVAIL:
if SYSTEM == "ipex":
ipex.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
)
@@ -117,7 +117,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
if IPEX_AVAIL:
if SYSTEM == "ipex":
ipex.distributed.all_gather(world_output, output, group=self.process_group)
else:
torch.distributed.all_gather(world_output, output, group=self.process_group)
@@ -217,7 +217,7 @@ def load(cls, config, prefix: str, weights, bias: bool):
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input)
if self.process_group.size() > 1 and reduce:
if IPEX_AVAIL:
if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
@@ -257,7 +257,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
)
out = torch.nn.functional.embedding(input, self.weight)
if self.reduce and self.process_group.size() > 1:
if IPEX_AVAIL:
if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group)
else:
torch.distributed.all_reduce(out, group=self.process_group)
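
The same ipex.distributed versus torch.distributed branch now appears at every collective call site in this file. A sketch of how it could be factored out; the helper below is hypothetical and the PR deliberately keeps the branches inline.

import torch

from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex


def all_reduce_sum(out: torch.Tensor, group) -> None:
    # Hypothetical helper mirroring the inline branches in tensor_parallel.py.
    if SYSTEM == "ipex":
        # oneCCL-backed collective shipped with intel_extension_for_pytorch.
        ipex.distributed.all_reduce(out, group=group)
    else:
        torch.distributed.all_reduce(out, group=group)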
@@ -20,9 +20,9 @@
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.utils.import_utils import IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM

if not IPEX_AVAIL:
if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe

from text_generation_server.layers.attention import (
@@ -24,9 +24,9 @@
import numpy as np

from torch import nn
from text_generation_server.utils.import_utils import IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM

if not IPEX_AVAIL:
if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
9 changes: 3 additions & 6 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -15,7 +15,7 @@

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK
@@ -768,12 +768,9 @@ def init_kv_cache(
empty_cache()

element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "xpu":
x = 1
else:
x = BLOCK_SIZE // element_size
x = BLOCK_SIZE // element_size

if IPEX_AVAIL and SYSTEM == "cpu":
if SYSTEM == "ipex" and device == torch.device("cpu"):
self.kv_cache = [
(
torch.empty(
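
A quick worked example of the now-unconditional key-cache packing factor: the removed SYSTEM == "xpu" branch forced x = 1, while every backend now uses BLOCK_SIZE // element_size. BLOCK_SIZE = 16 is assumed below purely for illustration; it is defined alongside the kv-cache code in TGI.

import torch

BLOCK_SIZE = 16  # assumed for illustration

element_size = torch.tensor([], dtype=torch.float16).element_size()  # 2 bytes for fp16
x = BLOCK_SIZE // element_size  # 8: the key cache packs the head dim in chunks of 8
print(x)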
12 changes: 6 additions & 6 deletions server/text_generation_server/models/flash_gpt2.py
@@ -15,7 +15,7 @@
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

@@ -34,12 +34,12 @@ def __init__(
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
else:
device = torch.device("cpu")
dtype = torch.float16 if dtype is None else dtype
elif IPEX_AVAIL:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGPT2 is only available on GPU")

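
The device-selection block above is repeated verbatim in the other flash_*.py model files touched by this commit. As a design note, the shared logic amounts to the hypothetical helper sketched here; the PR keeps it inline in each model class.

from typing import Optional, Tuple

import torch

from text_generation_server.utils.import_utils import SYSTEM


def pick_device_and_dtype(
    rank: int, dtype: Optional[torch.dtype] = None
) -> Tuple[torch.device, torch.dtype]:
    # Hypothetical helper; mirrors the inline blocks in flash_gpt2.py, flash_llama.py, etc.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}"), torch.float16 if dtype is None else dtype
    if SYSTEM == "ipex":
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
        else:
            device = torch.device("cpu")
        return device, torch.float16 if dtype is None else dtype
    raise NotImplementedError("Flash models are only available on GPU")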
12 changes: 6 additions & 6 deletions server/text_generation_server/models/flash_llama.py
@@ -17,7 +17,7 @@

tracer = trace.get_tracer(__name__)

from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM


class FlashLlama(FlashCausalLM):
@@ -34,12 +34,12 @@ def __init__(
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
else:
device = torch.device("cpu")
dtype = torch.float16 if dtype is None else dtype
elif IPEX_AVAIL:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashLlama is only available on GPU")

12 changes: 6 additions & 6 deletions server/text_generation_server/models/flash_mistral.py
@@ -16,7 +16,7 @@
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

@@ -38,12 +38,12 @@ def __init__(
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
else:
device = torch.device("cpu")
dtype = torch.float16 if dtype is None else dtype
elif IPEX_AVAIL:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashMistral is only available on GPU")

12 changes: 6 additions & 6 deletions server/text_generation_server/models/flash_neox.py
@@ -14,7 +14,7 @@
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

@@ -33,12 +33,12 @@ def __init__(
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
else:
device = torch.device("cpu")
dtype = torch.float16 if dtype is None else dtype
elif IPEX_AVAIL:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashNeoX is only available on GPU")

12 changes: 6 additions & 6 deletions server/text_generation_server/models/flash_rw.py
@@ -15,7 +15,7 @@
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

@@ -34,12 +34,12 @@ def __init__(
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
else:
device = torch.device("cpu")
dtype = torch.float16 if dtype is None else dtype
elif IPEX_AVAIL:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashRW is only available on GPU")

12 changes: 6 additions & 6 deletions server/text_generation_server/models/flash_santacoder.py
@@ -18,7 +18,7 @@
Weights,
)

from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

@@ -37,12 +37,12 @@ def __init__(
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
else:
device = torch.device("cpu")
dtype = torch.float16 if dtype is None else dtype
elif IPEX_AVAIL:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

4 changes: 2 additions & 2 deletions server/text_generation_server/utils/dist.py
@@ -3,7 +3,7 @@

from datetime import timedelta
from loguru import logger
from text_generation_server.utils.import_utils import IPEX_AVAIL
from text_generation_server.utils.import_utils import SYSTEM

# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
@@ -69,7 +69,7 @@ def initialize_torch_distributed():

if not torch.distributed.is_initialized():
# Call the init process.
if IPEX_AVAIL:
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex

ipex.distributed.init_process_group(