
Commit

lmdeploy: support dtype configuration
shell-nlp committed Oct 30, 2024
1 parent 264b612 commit 8115fa9
Showing 2 changed files with 8 additions and 3 deletions.
gpt_server/model_backend/lmdeploy_backend.py (9 changes: 8 additions & 1 deletion)
@@ -29,14 +29,21 @@ def __init__(self, model_path) -> None:
         backend = backend_map[os.getenv("backend")]
         enable_prefix_caching = bool(os.getenv("enable_prefix_caching", False))
         max_model_len = os.getenv("max_model_len", None)
+        dtype = os.getenv("dtype", "auto")
         logger.info(f"后端 {backend}")
         if backend == "pytorch":
-            backend_config = PytorchEngineConfig(tp=int(os.getenv("num_gpus", "1")))
+            backend_config = PytorchEngineConfig(
+                tp=int(os.getenv("num_gpus", "1")),
+                dtype=dtype,
+                session_len=int(max_model_len) if max_model_len else None,
+                enable_prefix_caching=enable_prefix_caching,
+            )
         if backend == "turbomind":
             backend_config = TurbomindEngineConfig(
                 tp=int(os.getenv("num_gpus", "1")),
                 enable_prefix_caching=enable_prefix_caching,
                 session_len=int(max_model_len) if max_model_len else None,
+                dtype=dtype,
             )
         pipeline_type, pipeline_class = get_task(model_path)
         logger.info(f"模型架构:{pipeline_type}")
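With this change, the lmdeploy backend reads dtype (alongside max_model_len and enable_prefix_caching) from environment variables and forwards them to the engine config. Below is a minimal standalone sketch of that flow, assuming lmdeploy re-exports PytorchEngineConfig, TurbomindEngineConfig, and pipeline at the top level; the pipeline() call is illustrative only, since gpt_server actually builds its pipeline through get_task().

import os

# Names mirror the diff above; the top-level imports are an assumption about
# the installed lmdeploy version.
from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig, pipeline

dtype = os.getenv("dtype", "auto")            # e.g. "auto", "float16", "bfloat16"
max_model_len = os.getenv("max_model_len", None)
tp = int(os.getenv("num_gpus", "1"))

if os.getenv("backend") == "pytorch":
    backend_config = PytorchEngineConfig(
        tp=tp,
        dtype=dtype,                          # newly configurable in this commit
        session_len=int(max_model_len) if max_model_len else None,
    )
else:  # "turbomind"
    backend_config = TurbomindEngineConfig(tp=tp, dtype=dtype)

# Hypothetical usage, not part of this repository:
# pipe = pipeline("/path/to/model", backend_config=backend_config)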
gpt_server/model_backend/vllm_backend.py (2 changes: 0 additions & 2 deletions)
@@ -27,9 +27,7 @@ class VllmBackend(ModelBackend):
     def __init__(self, model_path) -> None:
         lora = os.getenv("lora", None)
         enable_prefix_caching = bool(os.getenv("enable_prefix_caching", False))
-
         max_model_len = os.getenv("max_model_len", None)
-
         tensor_parallel_size = int(os.getenv("num_gpus", "1"))
         dtype = os.getenv("dtype", "auto")
         max_loras = 1
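Both backends are configured through the same environment variables. A hedged example of setting them before launch (the values are illustrative; the variable names come from the code above):

import os

os.environ["backend"] = "turbomind"       # or "pytorch" (lmdeploy backend)
os.environ["num_gpus"] = "1"              # tensor parallel size
os.environ["dtype"] = "bfloat16"          # read by vllm; now also passed to lmdeploy
os.environ["max_model_len"] = "8192"      # mapped to session_len / max_model_len
# Any non-empty string enables prefix caching, because bool("0") is still True;
# leave the variable unset (or empty) to keep it disabled.
os.environ["enable_prefix_caching"] = "1"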
