diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index 8e255b79c..b5dca3053 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -8,16 +8,21 @@ import asyncio import json from typing import List - from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse, JSONResponse import torch import uvicorn -from vllm import AsyncLLMEngine +import vllm +from packaging import version from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid +if version.parse(vllm.__version__) >= version.parse("0.2.0"): + from vllm.engine.async_llm_engine import AsyncLLMEngine +else: + from vllm import AsyncLLMEngine + from fastchat.serve.model_worker import ( BaseModelWorker, logger,