Skip to content

Commit

Permalink
modify device map
Browse files Browse the repository at this point in the history
  • Loading branch information
ivy-lv11 committed May 23, 2024
1 parent 9d8dab4 commit 1a74b06
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,20 @@ def messages_to_prompt(messages):
device_map=device,
)

print("----------------- Complete ------------------")
print("\n----------------- Complete ------------------")
completion_response = llm.complete(query)
print(completion_response.text)
print("----------------- Stream Complete ------------------")
print("\n----------------- Stream Complete ------------------")
response_iter = llm.stream_complete(query)
for response in response_iter:
print(response.delta, end="", flush=True)
print("----------------- Chat ------------------")
print("\n----------------- Chat ------------------")
from llama_index.core.llms import ChatMessage

message = ChatMessage(role="user", content=query)
resp = llm.chat([message])
print(resp)
print("----------------- Stream Chat ------------------")
print("\n----------------- Stream Chat ------------------")
message = ChatMessage(role="user", content=query)
resp = llm.stream_chat([message], max_tokens=256)
for r in resp:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
from threading import Thread
from typing import Any, Callable, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence, Literal

import torch
from llama_index.core.base.llms.types import (
Expand Down Expand Up @@ -94,7 +94,7 @@ class IpexLLM(CustomLLM):
),
)
device_map: str = Field(
default="auto", description="The device_map to use. Defaults to 'auto'."
default="cpu", description="The device_map to use. Defaults to 'cpu'."
)
stopping_ids: List[int] = Field(
default_factory=list,
Expand Down Expand Up @@ -145,7 +145,7 @@ def __init__(
load_in_low_bit: Optional[str] = None,
model: Optional[Any] = None,
tokenizer: Optional[Any] = None,
device_map: Optional[str] = "auto",
device_map: Literal["cpu", "xpu"] = "cpu",
stopping_ids: Optional[List[int]] = None,
tokenizer_kwargs: Optional[dict] = None,
tokenizer_outputs_to_remove: Optional[list] = None,
Expand All @@ -171,7 +171,7 @@ def __init__(
Unused if `model` is passed in directly.
model: The HuggingFace model.
tokenizer: The tokenizer.
device_map: The device_map to use. Defaults to 'auto'.
device_map: The device_map to use. Defaults to 'cpu'.
stopping_ids: The stopping ids to use.
Generation stops when these token IDs are predicted.
tokenizer_kwargs: The kwargs to pass to the tokenizer.
Expand All @@ -197,7 +197,11 @@ def __init__(
self._model = self._load_model(
low_bit_model, load_in_4bit, load_in_low_bit, model_name, model_kwargs
)

if device_map not in ["cpu", "xpu"]:
raise ValueError(
"IpexLLMEmbedding currently only supports device to be 'cpu' or 'xpu', "
f"but you have: {device_map}."
)
if "xpu" in device_map:
self._model = self._model.to(device_map)

Expand Down

0 comments on commit 1a74b06

Please sign in to comment.