Skip to content

Commit

Permalink
[core][misc] keep compatibility for old-style classes (#10356)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Nov 15, 2024
1 parent f2056f7 commit 3a763ba
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,34 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)
signatures = inspect.signature(model_class.__init__)
# collect all kw-only parameters
kw_only_params = [
param.name for param in signatures.parameters.values()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
assert "vllm_config" in kw_only_params and "prefix" in kw_only_params, \
("vLLM model class must accept `vllm_config` and `prefix` as kw-only "
"arguments. Possibly you have an old-style model class registered from "
"out of tree and it is used for new vLLM version. "
"Please check https://docs.vllm.ai/en/latest/design/class_hierarchy.html "
"for the design and update the model class accordingly.")
return model_class(vllm_config=vllm_config, prefix=prefix)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
return model_class(vllm_config=vllm_config, prefix=prefix)
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/class_hierarchy.html "
"for the design and update the model class accordingly.")
logger.warning(msg)
logger.warning(
"Trying to guess the arguments for old-style model class %s",
model_class)
# try to be compatible with old-style model class
kwargs = {}
if "prefix" in all_params:
kwargs["prefix"] = prefix
if "config" in all_params:
kwargs["config"] = model_config.hf_config
if "cache_config" in all_params:
kwargs["cache_config"] = vllm_config.cache_config
if "quant_config" in all_params:
kwargs["quant_config"] = vllm_config.quant_config
if "lora_config" in all_params:
kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
return model_class(**kwargs)


class BaseModelLoader(ABC):
Expand Down

0 comments on commit 3a763ba

Please sign in to comment.