
Commit 8959971

Temporary commit
icowan committed Sep 30, 2024
1 parent d37b1ee commit 8959971
Showing 4 changed files with 53 additions and 7 deletions.
2 changes: 1 addition & 1 deletion fastchat/conversation.py
@@ -343,7 +343,7 @@ def get_images(self):
             if i % 2 == 0:
                 if type(msg) is tuple:
                     for image in msg[1]:
-                        images.append(image.base64_str)
+                        images.append(image)

         return images

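The net effect is that get_images now returns each stored image entry as-is (for example a URL or a PIL image) instead of reaching for a base64_str attribute. A minimal caller-side sketch, assuming the conversation's image entries are plain URLs (the template name and URL are illustrative):

    from fastchat.conversation import get_conv_template

    # Hypothetical multimodal turn: FastChat stores such messages as
    # (text, [images]) tuples, which is what get_images() walks over.
    conv = get_conv_template("vicuna_v1.1")
    conv.append_message(conv.roles[0], ("Describe this picture.", ["https://example.com/cat.png"]))
    conv.append_message(conv.roles[1], None)

    # After this change the raw entries come back unmodified; callers that
    # previously expected base64 strings must load or encode images themselves.
    print(conv.get_images())  # ['https://example.com/cat.png']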
1 change: 1 addition & 0 deletions fastchat/protocol/openai_api_protocol.py
@@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
+    seed: Optional[int] = None


 class ChatMessage(BaseModel):
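The new seed field lets OpenAI-compatible clients ask for reproducible sampling. A quick sketch of a request exercising it (host, port, and model name are placeholders for whatever the deployment uses):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "vicuna-7b-v1.5",  # placeholder model name
            "messages": [{"role": "user", "content": "Tell me a joke."}],
            "temperature": 0.7,
            "seed": 42,  # new optional field; omit it for non-deterministic sampling
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])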
3 changes: 3 additions & 0 deletions fastchat/serve/openai_api_server.py
@@ -428,6 +428,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         return error_check_ret

     worker_addr = await get_worker_address(request.model)
+    logger.info(f"worker_addr: {worker_addr}")

     gen_params = await get_gen_params(
         request.model,
@@ -444,6 +445,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
         seed=request.seed,
     )

+    print(gen_params["prompt"])
+
     max_new_tokens, error_check_ret = await check_length(
         request,
         gen_params["prompt"],
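Since get_gen_params now forwards seed=request.seed, a worker can thread it into vLLM's sampler, whose SamplingParams accepts a seed argument. A sketch of that mapping, assuming gen_params carries the keys used in this file (the helper name is illustrative):

    from vllm.sampling_params import SamplingParams

    def build_sampling_params(gen_params: dict) -> SamplingParams:
        # Map FastChat gen_params keys onto vLLM's sampler; a None seed
        # leaves vLLM non-deterministic, matching the Optional[int] default.
        return SamplingParams(
            temperature=float(gen_params.get("temperature", 1.0)),
            top_p=float(gen_params.get("top_p", 1.0)),
            max_tokens=gen_params.get("max_new_tokens"),
            seed=gen_params.get("seed"),
        )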
54 changes: 48 additions & 6 deletions fastchat/serve/vllm_worker.py
@@ -22,6 +22,9 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 from vllm.inputs import TextPrompt
+from vllm.assets.image import ImageAsset
+from vllm.assets.video import VideoAsset
+from vllm.utils import FlexibleArgumentParser

 from fastchat.conversation import IMAGE_PLACEHOLDER_STR
 from fastchat.serve.base_model_worker import BaseModelWorker
@@ -88,6 +91,39 @@ def replace_placeholders_with_images(prompt: str, placeholder: str, images: List
     return prompt


+def get_multi_modal_input(args):
+    """
+    return {
+        "data": image or video,
+        "question": question,
+    }
+    """
+    if args.modality == "image":
+        # Input image and question
+        image = ImageAsset("cherry_blossom") \
+            .pil_image.convert("RGB")
+        img_question = "What is the content of this image?"
+
+        return {
+            "data": image,
+            "question": img_question,
+        }
+
+    if args.modality == "video":
+        # Input video and question
+        video = VideoAsset(name="sample_demo_1.mp4",
+                           num_frames=args.num_frames).np_ndarrays
+        vid_question = "Why is this video funny?"
+
+        return {
+            "data": video,
+            "question": vid_question,
+        }
+
+    msg = f"Modality {args.modality} is not supported."
+    raise ValueError(msg)
+
+
 class VLLMWorker(BaseModelWorker):
     def __init__(
         self,
Expand Down Expand Up @@ -214,13 +250,18 @@ async def generate_stream(self, params):
skip_special_tokens = params.get("skip_special_tokens", True)

request = params.get("request", None)
image_token = params.get("image_token", IMAGE_PLACEHOLDER_STR)

if images is None:
images = []


# split prompt by image token
# split_prompt = prompt.split("<image>")
# if prompt.count("<image>") != len(images):
# raise ValueError(
# "The number of images passed in does not match the number of <image> tokens in the prompt!"
# )
# split_prompt = prompt.split(IMAGE_PLACEHOLDER_STR)
if prompt.count(image_token) != len(images):
raise ValueError(
"The number of images passed in does not match the number of <image> tokens in the prompt!"
)

# context: List[TextPrompt] = []
# for i in range(len(split_prompt)):
@@ -232,7 +273,8 @@ async def generate_stream(self, params):
             "prompt": prompt,
         }
         if len(images) > 0:
-            context["multi_modal_data"] = {"image": load_image(images[0])},
+            # context["multi_modal_data"] = {"image": [load_image(url) for url in images]}
+            context["multi_modal_data"] = {"image": [load_image(url) for url in images]}

         # Handle stop_str
         stop = set()
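get_multi_modal_input mirrors vLLM's offline multimodal example and reads two attributes off the parsed arguments. A standalone sketch of driving it, assuming flags named after those attributes (this diff does not show them being registered on the worker's own parser):

    from vllm.utils import FlexibleArgumentParser

    # get_multi_modal_input as defined in this commit.
    parser = FlexibleArgumentParser()
    parser.add_argument("--modality", choices=["image", "video"], default="image")
    parser.add_argument("--num-frames", type=int, default=16)
    args = parser.parse_args(["--modality", "image"])

    mm_input = get_multi_modal_input(args)  # fetches vLLM's bundled "cherry_blossom" asset
    print(mm_input["question"])             # What is the content of this image?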
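The other substantive change is that generate_stream now hands vLLM every attached image rather than only images[0]. A condensed sketch of the assembled request, assuming load_image resolves a URL or path to a PIL image (its import is not shown in this diff, so it is passed in here) and engine is the worker's AsyncLLMEngine:

    from vllm.inputs import TextPrompt
    from vllm.sampling_params import SamplingParams
    from vllm.utils import random_uuid

    async def stream_text(engine, prompt, images, load_image):
        context: TextPrompt = {"prompt": prompt}
        if images:
            # One loaded image per placeholder token in the prompt.
            context["multi_modal_data"] = {"image": [load_image(url) for url in images]}
        sampling = SamplingParams(temperature=0.7, max_tokens=256)
        async for output in engine.generate(context, sampling, request_id=random_uuid()):
            yield output.outputs[0].text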
