diff --git a/Dockerfile b/Dockerfile index 153bff9cf565f..c981260a00d79 100644 --- a/Dockerfile +++ b/Dockerfile @@ -235,7 +235,9 @@ RUN mv vllm test_docs/ #################### OPENAI API SERVER #################### # openai api server alternative -FROM vllm-base AS vllm-openai + +# define sagemaker first, so it is not default from `docker build` +FROM vllm-base AS vllm-sagemaker # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ @@ -247,5 +249,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \ ENV VLLM_USAGE_SOURCE production-docker-image +# port 8080 required by sagemaker, https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server", "--port", "8080"] + +# the default image is `vllm-openai` which is identical to sagemaker without forcing the port +from vllm-sagemaker as vllm-openai + ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] #################### OPENAI API SERVER #################### diff --git a/docs/source/conf.py b/docs/source/conf.py index 1fe0474631140..71394c5302a39 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -191,6 +191,7 @@ def linkcode_resolve(domain, info): # Mock out external dependencies here, otherwise the autodoc pages may be blank. autodoc_mock_imports = [ + "blake3", "compressed_tensors", "cpuinfo", "cv2", @@ -207,7 +208,7 @@ def linkcode_resolve(domain, info): "tensorizer", "pynvml", "outlines", - "xgrammar," + "xgrammar", "librosa", "soundfile", "gguf", diff --git a/docs/source/design/multimodal/multimodal_index.md b/docs/source/design/multimodal/multimodal_index.md index 88af07afc7018..e4f2171e84ff7 100644 --- a/docs/source/design/multimodal/multimodal_index.md +++ b/docs/source/design/multimodal/multimodal_index.md @@ -45,39 +45,39 @@ adding_multimodal_plugin ### Base Classes ```{eval-rst} -.. autodata:: vllm.multimodal.NestedTensors +.. automodule:: vllm.multimodal.base + :members: + :show-inheritance: ``` -```{eval-rst} -.. autodata:: vllm.multimodal.BatchedTensorInputs -``` +### Input Classes ```{eval-rst} -.. autoclass:: vllm.multimodal.MultiModalDataBuiltins +.. automodule:: vllm.multimodal.inputs :members: :show-inheritance: ``` -```{eval-rst} -.. autodata:: vllm.multimodal.MultiModalDataDict -``` +### Audio Classes ```{eval-rst} -.. autoclass:: vllm.multimodal.MultiModalKwargs +.. automodule:: vllm.multimodal.audio :members: :show-inheritance: ``` +### Image Classes + ```{eval-rst} -.. autoclass:: vllm.multimodal.MultiModalPlugin +.. automodule:: vllm.multimodal.image :members: :show-inheritance: ``` -### Image Classes +### Video Classes ```{eval-rst} -.. automodule:: vllm.multimodal.image +.. automodule:: vllm.multimodal.video :members: :show-inheritance: ``` diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 95add0d71bbab..7acafda50793c 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -755,8 +755,7 @@ vLLM currently only supports adding LoRA to the language backbone of multimodal ``` ```{note} -To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo ({code}`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`) -and pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. 
+To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. ``` ```{note} diff --git a/docs/source/serving/deploying_with_k8s.md b/docs/source/serving/deploying_with_k8s.md index d27db826cd006..77f848088ea43 100644 --- a/docs/source/serving/deploying_with_k8s.md +++ b/docs/source/serving/deploying_with_k8s.md @@ -47,7 +47,11 @@ data: token: "REPLACE_WITH_TOKEN" ``` -Create a deployment file for vLLM to run the model server. The following example deploys the `Mistral-7B-Instruct-v0.3` model: +Next to create the deployment file for vLLM to run the model server. The following example deploys the `Mistral-7B-Instruct-v0.3` model. + +Here are two examples for using NVIDIA GPU and AMD GPU. + +- NVIDIA GPU ```yaml apiVersion: apps/v1 @@ -119,6 +123,79 @@ spec: periodSeconds: 5 ``` +- AMD GPU + +You can refer to the `deployment.yaml` below if using AMD ROCm GPU like MI300X. + +```yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: mistral-7b + namespace: default + labels: + app: mistral-7b +spec: + replicas: 1 + selector: + matchLabels: + app: mistral-7b + template: + metadata: + labels: + app: mistral-7b + spec: + volumes: + # PVC + - name: cache-volume + persistentVolumeClaim: + claimName: mistral-7b + # vLLM needs to access the host's shared memory for tensor parallel inference. + - name: shm + emptyDir: + medium: Memory + sizeLimit: "8Gi" + hostNetwork: true + hostIPC: true + containers: + - name: mistral-7b + image: rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 + securityContext: + seccompProfile: + type: Unconfined + runAsGroup: 44 + capabilities: + add: + - SYS_PTRACE + command: ["/bin/sh", "-c"] + args: [ + "vllm serve mistralai/Mistral-7B-v0.3 --port 8000 --trust-remote-code --enable-chunked-prefill --max_num_batched_tokens 1024" + ] + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + ports: + - containerPort: 8000 + resources: + limits: + cpu: "10" + memory: 20G + amd.com/gpu: "1" + requests: + cpu: "6" + memory: 6G + amd.com/gpu: "1" + volumeMounts: + - name: cache-volume + mountPath: /root/.cache/huggingface + - name: shm + mountPath: /dev/shm +``` +You can get the full example with steps and sample yaml files from . + 2. **Create a Kubernetes Service for vLLM** Next, create a Kubernetes Service file to expose the `mistral-7b` deployment: diff --git a/docs/source/usage/structured_outputs.md b/docs/source/usage/structured_outputs.md index 3f5d9ffc26278..7292012e36a26 100644 --- a/docs/source/usage/structured_outputs.md +++ b/docs/source/usage/structured_outputs.md @@ -2,7 +2,7 @@ # Structured Outputs -vLLM supports the generation of structured outputs using [outlines](https://github.com/dottxt-ai/outlines) or [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer) as backends for the guided decoding. +vLLM supports the generation of structured outputs using [outlines](https://github.com/dottxt-ai/outlines), [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer), or [xgrammar](https://github.com/mlc-ai/xgrammar) as backends for the guided decoding. This document shows you some examples of the different options that are available to generate structured outputs. 
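For orientation, here is a minimal sketch of requesting structured output through the OpenAI-compatible server using the `guided_choice` option; it assumes a vLLM server is already running locally on port 8000, and the model name is purely illustrative:

```python
from openai import OpenAI

# Point the client at a locally running vLLM OpenAI-compatible server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="-")

# Constrain the generated text to one of the listed choices via
# vLLM's guided decoding parameters passed through `extra_body`.
completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-3B-Instruct",
    messages=[{
        "role": "user",
        "content": "Classify this sentiment: vLLM is wonderful!",
    }],
    extra_body={"guided_choice": ["positive", "negative"]},
)
print(completion.choices[0].message.content)
```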
## Online Inference (OpenAI API) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index d5a71862656e7..77af914a6ef02 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -308,7 +308,20 @@ def run_mllama(question: str, modality: str): disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, ) - prompt = f"<|image|><|begin_of_text|>{question}" + tokenizer = AutoTokenizer.from_pretrained(model_name) + messages = [{ + "role": + "user", + "content": [{ + "type": "image" + }, { + "type": "text", + "text": f"{question}" + }] + }] + prompt = tokenizer.apply_chat_template(messages, + add_generation_prompt=True, + tokenize=False) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 51b255bb2a6db..61677b65af342 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -33,6 +33,7 @@ class MockModelConfig: hf_config = MockHFConfig() logits_processor_pattern = None diff_sampling_param: Optional[dict] = None + allowed_local_media_path: str = "" def get_diff_sampling_param(self): return self.diff_sampling_param or {} diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index 3731b2dcdeae1..c851539c610ec 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -91,5 +91,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, assert len(embeddings.data) == 1 assert len(embeddings.data[0].embedding) == 3072 assert embeddings.usage.completion_tokens == 0 - assert embeddings.usage.prompt_tokens == 765 - assert embeddings.usage.total_tokens == 765 + assert embeddings.usage.prompt_tokens == 764 + assert embeddings.usage.total_tokens == 764 diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 996e60bfee592..d63b963522e73 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -2,7 +2,6 @@ from typing import Optional import pytest -from PIL import Image from vllm.assets.image import ImageAsset from vllm.config import ModelConfig @@ -91,10 +90,7 @@ def _assert_mm_data_is_image_input( image_data = mm_data.get("image") assert image_data is not None - if image_count == 1: - assert isinstance(image_data, Image.Image) - else: - assert isinstance(image_data, list) and len(image_data) == image_count + assert isinstance(image_data, list) and len(image_data) == image_count def test_parse_chat_messages_single_image( diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 8b247fb9b2388..57ebaa424fc59 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest +import safetensors import torch import torch.nn as nn from huggingface_hub import snapshot_download @@ -169,6 +170,29 @@ def mixtral_lora_files_all_target_modules(): return snapshot_download(repo_id="dyang415/mixtral-lora-v0") +@pytest.fixture(scope="session") +def jamba_lora_files(): + # some of the adapters have unnecessary weights for serving, + # hence we remove them + def remove_unnecessary_weights(path): + lora_path = f"{adapter_path}/adapter_model.safetensors" + tensors = safetensors.torch.load_file(lora_path) + nonlora_keys = [] + for k in 
list(tensors.keys()): + if "lora" not in k: + nonlora_keys.append(k) + for k in nonlora_keys: + del tensors[k] + safetensors.torch.save_file(tensors, lora_path) + + adapter_path = snapshot_download( + repo_id= + "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora") + + remove_unnecessary_weights(adapter_path) + return adapter_path + + @pytest.fixture(scope="session") def gemma_lora_files(): return snapshot_download(repo_id="wskwon/gemma-7b-test-lora") diff --git a/tests/lora/test_jamba.py b/tests/lora/test_jamba.py new file mode 100644 index 0000000000000..6aa33926cb6b8 --- /dev/null +++ b/tests/lora/test_jamba.py @@ -0,0 +1,54 @@ +from typing import List + +import pytest +import torch + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini" + +MAX_TOKENS = 40 + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, + prompts: List[str]) -> List[str]: + + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("tp_size", [4]) +def test_jamba_lora(jamba_lora_files, tp_size): + """Original test, the LoRA model has the common target modules, not all""" + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + prompts = ["Write a story about a sheep and a goat."] + + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + distributed_executor_backend="ray", + tensor_parallel_size=tp_size, + ) + + expected_jamba_output = [ + """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. 
Clara was a gentle creature, always nibbling on the soft grass and humming""" # noqa: E501 + ] + assert do_sample(llm, jamba_lora_files, lora_id=1, + prompts=prompts) == expected_jamba_output diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py index cd8954ffc48c2..5897c04c89e19 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py @@ -30,7 +30,7 @@ def get_max_qwen2_vl_image_tokens(): @pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [ - ({}, 1225), + ({}, 16384), ({ MIN_PIXELS: 64**2, MAX_PIXELS: 512**2 diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 3101d1d2ea831..1a9c1b4ef1be0 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -201,6 +201,7 @@ vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output, num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + marks=[large_gpu_mark(min_gb=48)], ), "glm4": VLMTestInfo( models=["THUDM/glm-4v-9b"], @@ -212,7 +213,7 @@ dtype="bfloat16", get_stop_token_ids=lambda tok: [151329, 151336, 151338], patch_hf_runner=model_utils.glm_patch_hf_runner, - marks=[large_gpu_mark(min_gb=48)], + marks=[large_gpu_mark(min_gb=32)], ), "h2ovl": VLMTestInfo( models = [ @@ -261,6 +262,7 @@ dtype="bfloat16", use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, + marks=[large_gpu_mark(min_gb=32)], ), "llava_next": VLMTestInfo( models=["llava-hf/llava-v1.6-mistral-7b-hf"], diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index d22d778f81fa8..1b2847ed0f534 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -1,12 +1,20 @@ +from functools import partial from typing import cast +import numpy as np import pytest - -from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo, - find_text_matches, find_token_matches, - iter_placeholders, iter_token_matches, +from PIL import Image + +from vllm.config import ModelConfig +from vllm.inputs import InputProcessingContext +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, + _PlaceholderInfo, find_text_matches, + find_token_matches, iter_placeholders, + iter_token_matches, replace_text_matches, replace_token_matches) +from vllm.multimodal.utils import cached_get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import full_groupby @@ -457,6 +465,7 @@ def test_find_replace_tokens( ), ] ) +# yapf: enable def test_iter_placeholders( repl_by_key, prompt, @@ -475,11 +484,199 @@ def test_iter_placeholders( prompt_repls, prompt, # Effectively match all occurrences in the prompt - {key: 3 for key in repl_by_key}, - )) + {key: 3 + for key in repl_by_key}, + )) # Only displayed on error print("result:", result) # Manually constructed results assert result == expected + + +def _rand_img(rng: np.random.RandomState, min_wh: int, max_wh: int): + w, h = rng.randint(min_wh, max_wh, size=(2, )) + arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8) + return Image.fromarray(arr) + + +def _rand_video( + rng: np.random.RandomState, + min_frames: int, + 
max_frames: int, + min_wh: int, + max_wh: int, +): + # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 + num_frames = rng.randint(min_frames, max_frames) + num_frames = (num_frames // 2) * 2 + + w, h = rng.randint(min_wh, max_wh, size=(2, )) + return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8) + + +def _rand_audio( + rng: np.random.RandomState, + min_len: int, + max_len: int, + sr: int, +): + audio_len = rng.randint(min_len, max_len) + return rng.rand(audio_len), sr + + +def _test_processing_cache_correctness( + model_id: str, + modalities: set[str], + hit_rate: float, + num_batches: int, + simplify_rate: float, +): + if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3": + hf_overrides = {"architectures": ["MantisForConditionalGeneration"]} + else: + hf_overrides = {} + + model_config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=True, + seed=0, + dtype="float16", + revision=None, + hf_overrides=hf_overrides, + ) + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) + + processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls] + ctx = InputProcessingContext( + model_config, + tokenizer=cached_get_tokenizer(model_config.tokenizer), + ) + # Ensure that it can fit all of the data + cache = ProcessingCache(capacity=1 << 30) + + baseline_processor = processor_factory(ctx, cache=None) + cached_processor = processor_factory(ctx, cache=cache) + + rng = np.random.RandomState(0) + + input_to_hit = { + "image": Image.new("RGB", size=(128, 128)), + "video": np.zeros((4, 128, 128, 3), dtype=np.uint8), + "audio": (np.zeros((512, )), 16000), + } + input_factory = { + "image": + partial(_rand_img, rng, min_wh=128, max_wh=256), + "video": + partial(_rand_video, + rng, + min_frames=2, + max_frames=8, + min_wh=128, + max_wh=256), + "audio": + partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000), + } + input_max_count = { + "image": 3, + "video": 3, + "audio": 3, + } + + for batch_idx in range(num_batches): + mm_data = { + k: + [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) + for _ in range(rng.randint(input_max_count[k]))] + for k in modalities + } + + mm_counts = {k: len(vs) for k, vs in mm_data.items()} + prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text + + # Drop unnecessary keys and test single -> multi conversion + if rng.rand() < simplify_rate: + for k in list(mm_data.keys()): + if not mm_data[k]: + del mm_data[k] + elif len(mm_data[k]) == 1: + mm_data[k] = mm_data[k][0] + + baseline_result = baseline_processor.apply( + prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + cached_result = cached_processor.apply( + prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + assert baseline_result == cached_result, ( + f"Failed ({batch_idx=}, {mm_data=})") + + +# yapf: disable +@pytest.mark.parametrize(("model_id", "modalities"), [ + ("llava-hf/llava-1.5-7b-hf", {"image"}), + ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}), + ("mistral-community/pixtral-12b", {"image"}), + ("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}), + ("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}), + ("fixie-ai/ultravox-v0_3", {"audio"}), +]) +@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) +@pytest.mark.parametrize("num_batches", [32]) +@pytest.mark.parametrize("simplify_rate", [1.0]) +# yapf: enable +def test_processing_cache_correctness( + model_id: str, + modalities: set[str], + hit_rate: float, + num_batches: int, + 
simplify_rate: float, +): + _test_processing_cache_correctness( + model_id, + modalities, + hit_rate=hit_rate, + num_batches=num_batches, + simplify_rate=simplify_rate, + ) + + +# yapf: disable +@pytest.mark.parametrize(("model_id", "modalities"), [ + ("microsoft/Phi-3-vision-128k-instruct", {"image"}), +]) +@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) +@pytest.mark.parametrize("num_batches", [32]) +@pytest.mark.parametrize("simplify_rate", [1.0]) +# yapf: enable +def test_processing_cache_correctness_phi3v( + model_id: str, + modalities: set[str], + hit_rate: float, + num_batches: int, + simplify_rate: float, +): + # HACK - this is an attempted workaround for the following bug + # https://github.com/huggingface/transformers/issues/34307 + from transformers import AutoImageProcessor # noqa: F401 + from transformers import AutoProcessor # noqa: F401 + + AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True) + + _test_processing_cache_correctness( + model_id, + modalities, + hit_rate=hit_rate, + num_batches=num_batches, + simplify_rate=simplify_rate, + ) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index fd82fb0c55fd7..6029f2e514772 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -9,7 +9,7 @@ from PIL import Image, ImageChops from transformers import AutoConfig, AutoTokenizer -from vllm.multimodal.utils import (async_fetch_image, fetch_image, +from vllm.multimodal.utils import (MediaConnector, repeat_and_pad_placeholder_tokens) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) @@ -23,7 +23,12 @@ @pytest.fixture(scope="module") def url_images() -> Dict[str, Image.Image]: - return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS} + connector = MediaConnector() + + return { + image_url: connector.fetch_image(image_url) + for image_url in TEST_IMAGE_URLS + } def get_supported_suffixes() -> Tuple[str, ...]: @@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool: @pytest.mark.asyncio @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) async def test_fetch_image_http(image_url: str): - image_sync = fetch_image(image_url) - image_async = await async_fetch_image(image_url) + connector = MediaConnector() + + image_sync = connector.fetch_image(image_url) + image_async = await connector.fetch_image_async(image_url) assert _image_equals(image_sync, image_async) @@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str): @pytest.mark.parametrize("suffix", get_supported_suffixes()) async def test_fetch_image_base64(url_images: Dict[str, Image.Image], image_url: str, suffix: str): + connector = MediaConnector() url_image = url_images[image_url] try: @@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], base64_image = base64.b64encode(f.read()).decode("utf-8") data_url = f"data:{mime_type};base64,{base64_image}" - data_image_sync = fetch_image(data_url) + data_image_sync = connector.fetch_image(data_url) if _image_equals(url_image, Image.open(f)): assert _image_equals(url_image, data_image_sync) else: pass # Lossy format; only check that image can be opened - data_image_async = await async_fetch_image(data_url) + data_image_async = await connector.fetch_image_async(data_url) assert _image_equals(data_image_sync, data_image_async) @pytest.mark.asyncio @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) async def test_fetch_image_local_files(image_url: str): + connector = MediaConnector() + with 
TemporaryDirectory() as temp_dir: - origin_image = fetch_image(image_url) + local_connector = MediaConnector(allowed_local_media_path=temp_dir) + + origin_image = connector.fetch_image(image_url) origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)), quality=100, icc_profile=origin_image.info.get('icc_profile')) - image_async = await async_fetch_image( - f"file://{temp_dir}/{os.path.basename(image_url)}", - allowed_local_media_path=temp_dir) - - image_sync = fetch_image( - f"file://{temp_dir}/{os.path.basename(image_url)}", - allowed_local_media_path=temp_dir) + image_async = await local_connector.fetch_image_async( + f"file://{temp_dir}/{os.path.basename(image_url)}") + image_sync = local_connector.fetch_image( + f"file://{temp_dir}/{os.path.basename(image_url)}") # Check that the images are equal assert not ImageChops.difference(image_sync, image_async).getbbox() - with pytest.raises(ValueError): - await async_fetch_image( - f"file://{temp_dir}/../{os.path.basename(image_url)}", - allowed_local_media_path=temp_dir) - with pytest.raises(ValueError): - await async_fetch_image( + with pytest.raises(ValueError, match="must be a subpath"): + await local_connector.fetch_image_async( + f"file://{temp_dir}/../{os.path.basename(image_url)}") + with pytest.raises(RuntimeError, match="Cannot load local files"): + await connector.fetch_image_async( f"file://{temp_dir}/../{os.path.basename(image_url)}") - with pytest.raises(ValueError): - fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}", - allowed_local_media_path=temp_dir) - with pytest.raises(ValueError): - fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}") + with pytest.raises(ValueError, match="must be a subpath"): + local_connector.fetch_image( + f"file://{temp_dir}/../{os.path.basename(image_url)}") + with pytest.raises(RuntimeError, match="Cannot load local files"): + connector.fetch_image( + f"file://{temp_dir}/../{os.path.basename(image_url)}") @pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index 9033644e3264a..a46c67ad7e00e 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -21,12 +21,10 @@ class AudioAsset: name: Literal["winning_call", "mary_had_lamb"] @property - def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]: + def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", s3_prefix=ASSET_DIR) - y, sr = librosa.load(audio_path, sr=None) - assert isinstance(sr, int) - return y, sr + return librosa.load(audio_path, sr=None) @property def url(self) -> str: diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 3df08c740d65b..a492d5496e025 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -6,7 +6,7 @@ from functools import lru_cache, partial from pathlib import Path from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, - Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) + Literal, Optional, Tuple, TypeVar, Union, cast) import jinja2.nodes import transformers.utils.chat_template_utils as hf_chat_utils @@ -23,6 +23,8 @@ ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) from openai.types.chat import (ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam) +from openai.types.chat.chat_completion_content_part_input_audio_param import ( + InputAudio) # yapf: enable # pydantic needs the TypedDict from 
typing_extensions from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -31,11 +33,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.utils import (async_get_and_parse_audio, - async_get_and_parse_image, - async_get_and_parse_video, - get_and_parse_audio, get_and_parse_image, - get_and_parse_video) +from vllm.multimodal.utils import MediaConnector from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import print_warning_once @@ -368,14 +366,17 @@ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): self._tokenizer = tokenizer self._allowed_items = (model_config.multimodal_config.limit_per_prompt if model_config.multimodal_config else {}) - self._consumed_items = {k: 0 for k in self._allowed_items} - self._items: List[_T] = [] + self._items_by_modality = defaultdict[str, list[_T]](list) @property def model_config(self) -> ModelConfig: return self._model_config + @property + def allowed_local_media_path(self): + return self._model_config.allowed_local_media_path + @staticmethod @lru_cache(maxsize=None) def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: @@ -435,38 +436,19 @@ def _placeholder_str(self, modality: ModalityStr, else: raise TypeError(f"Unknown modality: {modality}") - @staticmethod - def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict: - mm_lists: Mapping[str, List[object]] = defaultdict(list) - - # Merge all the multi-modal items - for single_mm_data in items: - for mm_key, mm_item in single_mm_data.items(): - if isinstance(mm_item, list): - mm_lists[mm_key].extend(mm_item) - else: - mm_lists[mm_key].append(mm_item) - - # Unpack any single item lists for models that don't expect multiple. - return { - mm_key: mm_list[0] if len(mm_list) == 1 else mm_list - for mm_key, mm_list in mm_lists.items() - } - def add(self, modality: ModalityStr, item: _T) -> Optional[str]: """ Add a multi-modal item to the current prompt and returns the placeholder string to use, if any. 
""" allowed_count = self._allowed_items.get(modality, 1) - current_count = self._consumed_items.get(modality, 0) + 1 + current_count = len(self._items_by_modality[modality]) + 1 if current_count > allowed_count: raise ValueError( f"At most {allowed_count} {modality}(s) may be provided in " "one request.") - self._consumed_items[modality] = current_count - self._items.append(item) + self._items_by_modality[modality].append(item) return self._placeholder_str(modality, current_count) @@ -475,22 +457,26 @@ def create_parser(self) -> "BaseMultiModalContentParser": raise NotImplementedError -class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]): +class MultiModalItemTracker(BaseMultiModalItemTracker[object]): def all_mm_data(self) -> Optional[MultiModalDataDict]: - return self._combine(self._items) if self._items else None + if self._items_by_modality: + return dict(self._items_by_modality) + + return None def create_parser(self) -> "BaseMultiModalContentParser": return MultiModalContentParser(self) -class AsyncMultiModalItemTracker( - BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]): +class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): async def all_mm_data(self) -> Optional[MultiModalDataDict]: - if self._items: - items = await asyncio.gather(*self._items) - return self._combine(items) + if self._items_by_modality: + return { + modality: await asyncio.gather(*items) + for modality, items in self._items_by_modality.items() + } return None @@ -522,7 +508,7 @@ def parse_audio(self, audio_url: str) -> None: raise NotImplementedError @abstractmethod - def parse_input_audio(self, input_audio: Dict[str, str]) -> None: + def parse_input_audio(self, input_audio: InputAudio) -> None: raise NotImplementedError @abstractmethod @@ -537,31 +523,31 @@ def __init__(self, tracker: MultiModalItemTracker) -> None: self._tracker = tracker + self._connector = MediaConnector( + allowed_local_media_path=tracker.allowed_local_media_path, + ) + def parse_image(self, image_url: str) -> None: - image = get_and_parse_image(image_url, - allowed_local_media_path=self._tracker. 
- _model_config.allowed_local_media_path) + image = self._connector.fetch_image(image_url) placeholder = self._tracker.add("image", image) self._add_placeholder(placeholder) def parse_audio(self, audio_url: str) -> None: - audio = get_and_parse_audio(audio_url) + audio = self._connector.fetch_audio(audio_url) placeholder = self._tracker.add("audio", audio) self._add_placeholder(placeholder) - def parse_input_audio(self, input_audio: Dict[str, str]) -> None: - input_audio_data = input_audio.get("data","") - input_audio_format = input_audio.get("format","") - audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" - audio = get_and_parse_audio(audio_url) + def parse_input_audio(self, input_audio: InputAudio) -> None: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + audio_url = f"data:audio/{audio_format};base64,{audio_data}" - placeholder = self._tracker.add("audio", audio) - self._add_placeholder(placeholder) + return self.parse_audio(audio_url) def parse_video(self, video_url: str) -> None: - video = get_and_parse_video(video_url) + video = self._connector.fetch_video(video_url) placeholder = self._tracker.add("video", video) self._add_placeholder(placeholder) @@ -573,33 +559,31 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: super().__init__() self._tracker = tracker + self._connector = MediaConnector( + allowed_local_media_path=tracker.allowed_local_media_path, + ) def parse_image(self, image_url: str) -> None: - image_coro = async_get_and_parse_image( - image_url, - allowed_local_media_path=self._tracker._model_config. - allowed_local_media_path) + image_coro = self._connector.fetch_image_async(image_url) placeholder = self._tracker.add("image", image_coro) self._add_placeholder(placeholder) def parse_audio(self, audio_url: str) -> None: - audio_coro = async_get_and_parse_audio(audio_url) + audio_coro = self._connector.fetch_audio_async(audio_url) placeholder = self._tracker.add("audio", audio_coro) self._add_placeholder(placeholder) - def parse_input_audio(self, input_audio: Dict[str, str]) -> None: - input_audio_data = input_audio.get("data","") - input_audio_format = input_audio.get("format","") - audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" - audio_coro = async_get_and_parse_audio(audio_url) + def parse_input_audio(self, input_audio: InputAudio) -> None: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + audio_url = f"data:audio/{audio_format};base64,{audio_data}" - placeholder = self._tracker.add("audio", audio_coro) - self._add_placeholder(placeholder) + return self.parse_audio(audio_url) def parse_video(self, video_url: str) -> None: - video = async_get_and_parse_video(video_url) + video = self._connector.fetch_video_async(video_url) placeholder = self._tracker.add("video", video) self._add_placeholder(placeholder) @@ -695,10 +679,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam) +_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio] + # Define a mapping from part types to their corresponding parsing functions. 
-MM_PARSER_MAP: Dict[str, - Callable[[ChatCompletionContentPartParam], - Union[str, Dict[str,str]]]] = { +MM_PARSER_MAP: Dict[ + str, + Callable[[ChatCompletionContentPartParam], _ContentPart], +] = { "text": lambda part: _TextParser(part).get("text", ""), "image_url": @@ -715,8 +702,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], def _parse_chat_message_content_mm_part( - part: ChatCompletionContentPartParam) -> Tuple[str, - Union[str, Dict[str, str]]]: + part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: """ Parses a given multi-modal content part based on its type. @@ -783,7 +769,7 @@ def _parse_chat_message_content_parts( *, wrap_dicts: bool, ) -> List[ConversationMessage]: - content: List[Union[str, Dict[str, str]]] = [] + content = list[_ContentPart]() mm_parser = mm_tracker.create_parser() @@ -814,7 +800,7 @@ def _parse_chat_message_content_part( mm_parser: BaseMultiModalContentParser, *, wrap_dicts: bool, -) -> Optional[Union[str, Dict[str, str]]]: +) -> Optional[_ContentPart]: """Parses a single part of a conversation. If wrap_dicts is True, structured dictionary pieces for texts and images will be wrapped in dictionaries, i.e., {"type": "text", "text", ...} and @@ -823,8 +809,7 @@ def _parse_chat_message_content_part( with multimodal placeholders. """ if isinstance(part, str): # Handle plain text parts - text = _TextParser(part) - return text + return part # Handle structured dictionary parts part_type, content = _parse_chat_message_content_mm_part(part) @@ -855,7 +840,7 @@ def _parse_chat_message_content_part( return {'type': 'audio'} if wrap_dicts else None if part_type == "input_audio": - dict_content = cast(Dict[str, str], content) + dict_content = cast(InputAudio, content) mm_parser.parse_input_audio(dict_content) return {'type': 'audio'} if wrap_dicts else None diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2e45b474237f9..befe2e6ead97e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -16,7 +16,7 @@ from typing import AsyncIterator, Optional, Set, Tuple import uvloop -from fastapi import APIRouter, FastAPI, Request +from fastapi import APIRouter, FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -44,11 +44,15 @@ CompletionResponse, DetokenizeRequest, DetokenizeResponse, + EmbeddingChatRequest, + EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, ErrorResponse, LoadLoraAdapterRequest, + PoolingChatRequest, + PoolingCompletionRequest, PoolingRequest, PoolingResponse, ScoreRequest, ScoreResponse, TokenizeRequest, @@ -315,6 +319,12 @@ async def health(raw_request: Request) -> Response: return Response(status_code=200) +@router.post("/ping") +async def ping(raw_request: Request) -> Response: + """Ping check. Endpoint required for SageMaker""" + return await health(raw_request) + + @router.post("/tokenize") @with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): @@ -488,6 +498,49 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) +TASK_HANDLERS = { + "generate": { + "messages": (ChatCompletionRequest, create_chat_completion), + "default": (CompletionRequest, create_completion), + }, + # should `embed` be a pooling task by default? 
+ "embed": { + "messages": (EmbeddingChatRequest, create_embedding), + "default": (EmbeddingCompletionRequest, create_embedding), + }, + "score": { + "messages": (PoolingChatRequest, create_score), + "default": (PoolingCompletionRequest, create_score), + }, +} + + +@router.post("/invocations") +async def invocations(raw_request: Request): + """ + For SageMaker, routes requests to other handlers based on model `task`. + """ + body = await raw_request.json() + task = raw_request.app.state.task + + if task not in TASK_HANDLERS: + raise HTTPException( + status_code=400, + detail=f"Unsupported task: '{task}' for '/invocations'. " + f"Expected one of {set(TASK_HANDLERS.keys())}" + ) + + handler_config = TASK_HANDLERS[task] + if "messages" in body: + request_model, handler = handler_config["messages"] + else: + request_model, handler = handler_config["default"] + + # this is required since we lose the FastAPI automatic casting + request = request_model.model_validate(body) + return await handler(request, raw_request) + + if envs.VLLM_TORCH_PROFILER_DIR: logger.warning( "Torch Profiler is enabled in the API server. This should ONLY be " @@ -694,6 +747,7 @@ def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) + state.task = model_config.task def create_server_socket(addr: Tuple[str, int]) -> socket.socket: diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index f3ec9d115c9ba..46346b08e99c2 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -99,6 +99,9 @@ def get_hf_processor( merged_kwargs = {**base_kwargs, **kwargs} + if isinstance(typ, type): + merged_kwargs["processor_cls"] = typ + hf_processor = cached_get_processor( self.model_config.model, trust_remote_code=self.model_config.trust_remote_code, @@ -132,10 +135,13 @@ def get_hf_processor( def call_hf_processor( self, hf_processor: ProcessorMixin, - prompt: str, - processor_data: Mapping[str, object], - inference_kwargs: Mapping[str, object], + data: Mapping[str, object], + kwargs: Mapping[str, object] = {}, ) -> BatchFeature: + """ + Call :code:`hf_processor` on the prompt :code:`data` + (text, image, audio...) with configurable options :code:`kwargs`. 
+ """ assert callable(hf_processor) base_kwargs = self.model_config.mm_processor_kwargs @@ -144,21 +150,15 @@ def call_hf_processor( merged_kwargs = resolve_mm_processor_kwargs( base_kwargs, - inference_kwargs, + kwargs, hf_processor, requires_kw_only=False, allow_var_kwargs=True, ) try: - return hf_processor( - text=prompt, - **processor_data, - **merged_kwargs, - return_tensors="pt", - ) + return hf_processor(**data, **merged_kwargs, return_tensors="pt") except Exception as exc: - data = dict(text=prompt, **processor_data) msg = (f"Failed to apply {type(hf_processor).__name__} " f"on data={data} with kwargs={merged_kwargs}") diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 10bec75f49fdf..606c796d503cf 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -42,12 +42,14 @@ def __init__(self, use_rms_norm: bool, rms_norm_has_weight: bool = True, rms_norm_eps: float = 1e-5, - activation="silu"): + activation="silu", + is_lora_enabled: bool = False): super().__init__() self.time_step_rank = time_step_rank self.ssm_state_size = ssm_state_size self.use_rms_norm = use_rms_norm self.activation = activation + self.is_lora_enabled = is_lora_enabled self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, @@ -63,6 +65,7 @@ def __init__(self, self.in_proj = MergedColumnParallelLinear(hidden_size, [intermediate_size] * 2, bias=use_bias) + # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( intermediate_size, @@ -170,7 +173,13 @@ def forward_cuda(self, hidden_states: torch.Tensor, # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + if self.is_lora_enabled: + # lora kernel requires contiguous tensor + ssm_parameters = self.x_proj( + hidden_states.transpose(-2, -1).contiguous())[0] + else: + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] time_step, B, C = torch.split( ssm_parameters, @@ -222,6 +231,11 @@ def forward_cuda(self, hidden_states: torch.Tensor, scan_outputs = scan_outputs.transpose(0, 1) # 4. 
Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] + if self.is_lora_enabled: + # lora kernel requires contiguous tensor + contextualized_states = self.out_proj( + scan_outputs.transpose(-2, -1).contiguous())[0] + else: + contextualized_states = self.out_proj( + scan_outputs.transpose(-2, -1))[0] return contextualized_states diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 73cc8ce0d2a4b..1d4e4bd52adaa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -41,10 +41,12 @@ def process_weights_after_loading(self, layer) -> None: ) if current_platform.is_rocm(): + input_scale = getattr(layer, 'input_scale', None) + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=max_w_scale, - input_scale=layer.input_scale) + input_scale=input_scale) if input_scale is not None: layer.input_scale = Parameter(input_scale, requires_grad=False) @@ -57,11 +59,13 @@ def process_weights_after_loading(self, layer) -> None: weight = layer.weight if current_platform.is_rocm(): + input_scale = getattr(layer, 'input_scale', None) + weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=layer.weight_scale, - input_scale=layer.input_scale) + input_scale=input_scale) if input_scale is not None: layer.input_scale = Parameter(input_scale, requires_grad=False) @@ -76,7 +80,7 @@ def process_weights_after_loading(self, layer) -> None: raise ValueError(f"Unknown quantization strategy {self.strategy}") # INPUT SCALE - if self.is_static_input_scheme: + if self.is_static_input_scheme and hasattr(layer, 'input_scale'): layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) else: diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index f2d9293b31a83..a9c1fa7221217 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -11,7 +11,8 @@ import warnings from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast +from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, + Tuple, cast) import gguf import huggingface_hub @@ -706,6 +707,8 @@ def __init__(self, load_config: LoadConfig): # Store all module names (from transformers) that support # BNB quantization. self.target_modules: List[str] = [] + # mapping weight names from transformers to vllm. + self.weight_mapper: Callable = lambda name: name def _get_weight_files( self, @@ -763,9 +766,12 @@ def _prepare_weights(self, model_name_or_path: str, def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): if use_safetensors: - return safetensors_weights_iterator(hf_weights_files) + iterator = safetensors_weights_iterator(hf_weights_files) else: - return pt_weights_iterator(hf_weights_files) + iterator = pt_weights_iterator(hf_weights_files) + for name, param in iterator: + # mapping weight names from transformers to vllm. 
+ yield self.weight_mapper(name), param def _get_quantized_weights_iterator( self, @@ -782,12 +788,12 @@ def _get_quantized_weights_iterator( try: import bitsandbytes - if bitsandbytes.__version__ < "0.44.0": + if bitsandbytes.__version__ < "0.45.0": raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.44.0.") + "install bitsandbytes>=0.45.0.") except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.44.0 via " - "`pip install bitsandbytes>=0.44.0` to use " + raise ImportError("Please install bitsandbytes>=0.45.0 via " + "`pip install bitsandbytes>=0.45.0` to use " "bitsandbytes quantizer.") from err hf_weights_files, use_safetensors = self._prepare_weights( @@ -991,12 +997,15 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None: if isinstance(module, (LinearBase, )): last_name = name.split(".")[-1] if sub_modules := inverse_stacked_mapping.get(last_name, []): - # Map vllm's names to transformers' names. + # Map vllm's names to transformers's names. for sub_name in sub_modules: self.target_modules.append( name.replace(last_name, sub_name)) - else: - self.target_modules.append(name) + # Add original module name even if the module has stacked map, + # in case model has a mixture of disk-merged and disk-splitted + # weights with same last name. + self.target_modules.append(name) + assert (self.target_modules ), "vllm currently does not support BNB quantization for" f" {type(model).__name__}" @@ -1013,6 +1022,10 @@ def _load_weights(self, model_config: ModelConfig, f"Model {type(model).__name__} does not support BitsAndBytes " "quantization yet.") + # For some models like Molmo, we need to use hf_to_vllm_mapper + # to ensure correct loading of weights. + if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): + self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) # Modules whose weights might have fused on disk # we need their output_sizes to make shard in flight correctly with TP self.maybe_fused_weights_modules: Dict[str, List[int]] = {} diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 91786db5ddc96..890b5530b97d6 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -107,9 +107,11 @@ def __init__(self, layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + is_lora_enabled: Optional[bool] = False, + **kwargs) -> None: super().__init__() self.config = config + self.is_lora_enabled = is_lora_enabled self.mamba = MambaMixer(hidden_size= config.hidden_size, ssm_state_size = config.mamba_d_state, conv_kernel_size = config.mamba_d_conv, @@ -120,7 +122,9 @@ def __init__(self, use_bias = config.mamba_proj_bias, use_rms_norm=True, rms_norm_eps=config.rms_norm_eps, - activation=config.hidden_act) + activation=config.hidden_act, + is_lora_enabled = self.is_lora_enabled + ) num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP @@ -156,14 +160,13 @@ def forward( class JambaAttentionDecoderLayer(nn.Module): - def __init__( - self, - config: JambaConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + **kwargs) -> 
None: super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -287,17 +290,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) + extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} + def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) layer_class = ALL_DECODER_LAYER_TYPES[ config.layers_block_type[layer_idx]] - return layer_class( - config, - layer_idx, - cache_config, - quant_config=quant_config, - prefix=prefix, - ) + return layer_class(config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + **extra_kwargs) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") @@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, "k_proj", "v_proj", ], + "in_proj": ["in_proj"], } # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", + "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj", + "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj" ] embedding_modules = { "embed_tokens": "input_embeddings", @@ -423,9 +426,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) if self.scheduler_config is not None and \ - not self.model_config.enforce_eager: + not self.model_config.enforce_eager: if self.scheduler_config.max_num_seqs > \ - vllm_config.compilation_config.max_capture_size: + vllm_config.compilation_config.max_capture_size: self.max_batch_size = \ vllm_config.compilation_config.max_capture_size else: @@ -446,7 +449,6 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 0662d90e79b92..0ecba5a1cae0f 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,5 +1,4 @@ from functools import cached_property -from types import MethodType from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) @@ -7,7 +6,7 @@ import torch.nn as nn from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, PretrainedConfig, - ProcessorMixin, SiglipVisionConfig) + SiglipVisionConfig) from transformers.models.llava import LlavaProcessor from transformers.models.pixtral import PixtralProcessor @@ -21,10 +20,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, + MultiModalFieldConfig, MultiModalInputsV2, + MultiModalKwargs, NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement) + ProcessorInputs, PromptReplacement, + full_groupby_modality) from vllm.sequence import IntermediateTensors from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -116,36 +117,54 @@ def get_max_llava_image_tokens(ctx: 
InputContext): class LlavaMultiModalProcessor(BaseMultiModalProcessor): - def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): - if getattr(hf_processor, "__is_patched__", False): - return # Already patched - - image_processor = hf_processor.image_processor # type: ignore - orig_preprocess = image_processor.preprocess + def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: + return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor)) - def preprocess(__self, *args, **kwargs): - hf_inputs = orig_preprocess(*args, **kwargs) - hf_inputs["is_pixtral"] = torch.tensor(True) - return hf_inputs + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) - image_processor.preprocess = MethodType(preprocess, image_processor) + # NOTE: pixel_values=None for MLlavaProcessor + pixel_values = processed_outputs.get("pixel_values") + if pixel_values is not None: + images = mm_data["images"] + assert isinstance(images, list) - hf_processor.__is_patched__ = True # type: ignore + if isinstance(self._get_hf_processor(), PixtralProcessor): + # Original output: (1, num_images, C, H, W) + # New output: (num_images, C, H, W) + assert (isinstance(pixel_values, list) + and len(pixel_values) == 1 + and isinstance(pixel_values[0], list) + and len(pixel_values[0]) == len(images)) - def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: - hf_processor = self.ctx.get_hf_processor( - (LlavaProcessor, PixtralProcessor)) + processed_outputs["pixel_values"] = pixel_values[0] - if isinstance(hf_processor, PixtralProcessor): - self._patch_pixtral_processor(hf_processor) + return processed_outputs - return hf_processor + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_config = self.ctx.get_hf_config(LlavaConfig) image_token_id = hf_config.image_token_index @@ -200,7 +219,7 @@ def _get_dummy_mm_inputs( ) -> ProcessorInputs: hf_config = self.ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config - num_images = mm_counts["image"] + num_images = mm_counts.get("image", 0) if isinstance(vision_config, CLIPVisionConfig): data = dummy_image_for_clip(vision_config, num_images) @@ -218,7 +237,6 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text=image_token * num_images, mm_data=data, - mm_processor_kwargs={}, ) @@ -379,7 +397,6 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) - is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False])) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: @@ -390,33 +407,6 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. 
" f"Got type: {type(pixel_values)}") - assert isinstance(is_pixtral, torch.Tensor) - if is_pixtral.any(): - images = pixel_values - - def flatten_to_3d_tensors(item): - if isinstance(item, torch.Tensor): - if item.dim() >= 3: - return [t for t in item.view(-1, *item.shape[-3:])] - else: - raise ValueError( - f"Unexpected tensor dimension: {item.dim()}") - elif isinstance(item, list): - return [ - t for subitem in item - for t in flatten_to_3d_tensors(subitem) - ] - else: - raise ValueError(f"Unexpected type: {type(item)}") - - # Restructure the batched images into a list of lists of images - images = flatten_to_3d_tensors(pixel_values) - - return LlavaImagePixelInputs( - type="pixel_values", - data=images, - ) - return LlavaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values( @@ -586,19 +576,71 @@ def load_weights(self, weights: Iterable[Tuple[str, class MantisMultiModalProcessor(LlavaMultiModalProcessor): - def _get_hf_processor(self) -> ProcessorMixin: - try: - from mantis.models.mllava import MLlavaProcessor - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "You need to `pip install " - "git+https://github.com/TIGER-AI-Lab/Mantis.git` " - "to use this model") from exc - - processor = MLlavaProcessor.from_pretrained( - self.ctx.model_config.tokenizer) - assert isinstance(processor, ProcessorMixin) - return processor + def _get_hf_processor(self): + return self.ctx.get_hf_processor(LlavaProcessor) + + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + hf_config = self.ctx.get_hf_config(LlavaConfig) + image_token_id = hf_config.image_token_index + max_image_tokens = get_max_llava_image_tokens(self.ctx) + + result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) + + mm_items = self._get_mm_items(mm_data) + mm_item_counts = mm_items.get_item_counts() + mm_kwargs = result["mm_kwargs"] + + # We reimplement the functionality of MLlavaProcessor from + # https://github.com/TIGER-AI-Lab/Mantis.git + def get_replacement_mantis(item_idx: int): + return "".join([ + f"(image {item_idx+1}: ", # 7 tokens + "" * max_image_tokens, + ")", # 3 tokens + ]) + + mantis_repls = self._bind_prompt_replacements([ + PromptReplacement( + modality="image", + target=[image_token_id] * max_image_tokens, + replacement=get_replacement_mantis, + ) + ]) + + prompt_ids, prompt_text, _ = self._apply_prompt_replacements( + result["prompt_token_ids"], + mantis_repls, + mm_item_counts, + ) + + unbound_orig_repls = self._get_prompt_replacements( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + orig_repls = self._bind_prompt_replacements(unbound_orig_repls) + + all_placeholders = self._find_placeholders(orig_repls, prompt_ids, + mm_item_counts) + assert len(all_placeholders) == mm_item_counts.get("image", 0) + + mm_placeholders = { + modality: [item.to_range() for item in items] + for modality, items in full_groupby_modality(all_placeholders) + } + + return MultiModalInputsV2( + type="multimodal", + prompt=prompt_text, + prompt_token_ids=prompt_ids, + mm_kwargs=mm_kwargs, + mm_placeholders=mm_placeholders, + ) # To use this model, please use diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 06c8d9723cd01..553bc9c28cb21 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module): def __init__(self, config: MambaConfig, cache_config: 
Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + is_lora_enabled: Optional[bool] = False) -> None: super().__init__() self.config = config self.is_falcon_mamba = config.model_type == "falcon_mamba" + self.is_lora_enabled = is_lora_enabled mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None self.mixer = MambaMixer(hidden_size=config.hidden_size, ssm_state_size=config.state_size, @@ -53,7 +55,8 @@ def __init__(self, use_rms_norm=self.is_falcon_mamba, rms_norm_has_weight=not self.is_falcon_mamba, rms_norm_eps=mixer_rms_eps, - activation=config.hidden_act) + activation=config.hidden_act, + is_lora_enabled=self.is_lora_enabled) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -85,6 +88,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + is_lora_enabled = bool(lora_config) self.config = config self.padding_idx = config.pad_token_id @@ -101,8 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MambaDecoderLayer( - config, cache_config=cache_config, quant_config=quant_config), + lambda prefix: MambaDecoderLayer(config, + cache_config=cache_config, + quant_config=quant_config, + is_lora_enabled=is_lora_enabled), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 8938f62d0c494..5d52d2c3e6b48 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -461,30 +461,71 @@ def forward( return output -class MolmoMLP(nn.Module): +class SwiGLU(nn.Module): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gate = x.chunk(2, dim=-1) + # Note that the order is reversed compared to + # SiluAndMul. + return x * F.silu(gate) + + +class LanuageModelMLP(nn.Module): """Molmo's LLM mlp.""" def __init__(self, config: PretrainedConfig, input_dim: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - proj_name: str = "gate_up_proj") -> None: + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size // 2 - # Molmo's LLM proj weights are already merged into the disk, while - # image_projector proj is separate. If the same proj_name were used, it - # would create ambiguity and make it difficult to support BNB and LoRA. - self.proj_name = proj_name - setattr( - self, proj_name, - MergedColumnParallelLinear( - input_dim or self.hidden_size, - [self.intermediate_size] * 2, - bias=False, - quant_config=quant_config, - )) + self.gate_up_proj = MergedColumnParallelLinear( + input_dim or self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + # Activation function. + self.act_fn = SwiGLU() + # Feed-forward output projection. 
+ self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class ImageProjectorMLP(nn.Module): + """Molmo's image_projector mlp.""" + + def __init__( + self, + config: PretrainedConfig, + input_dim: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size // 2 + + self.merged_linear = MergedColumnParallelLinear( + input_dim or self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) # Activation function. self.act_fn = SiluAndMul() @@ -500,7 +541,7 @@ def forward( self, x: torch.Tensor, ) -> torch.Tensor: - gate_up, _ = getattr(self, self.proj_name)(x) + gate_up, _ = self.merged_linear(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) return x @@ -523,9 +564,7 @@ def __init__( prefix=f"{prefix}.self_attn") # MLP block. - self.mlp = MolmoMLP(config, - quant_config=quant_config, - proj_name="gate_up_proj") + self.mlp = LanuageModelMLP(config, quant_config=quant_config) # LayerNorm assert config.layer_norm_type == "rms" @@ -617,11 +656,10 @@ def __init__( vision_config, nlayers=len(self.vit_layers), quant_config=quant_config) - self.image_projector = MolmoMLP( + self.image_projector = ImageProjectorMLP( config, input_dim=vision_config.image_emb_dim, quant_config=quant_config, - proj_name="merged_linear", ) image_dim = vision_config.image_emb_dim * len(self.vit_layers) @@ -842,10 +880,6 @@ def load_weights(self, weights: Iterable[Tuple[str, loaded_params: Set[str] = set() for name, loaded_weight in weights: - if "gate_up_proj" in name: - up_proj, gate_proj = loaded_weight.chunk(2, dim=0) - loaded_weight = torch.cat([gate_proj, up_proj], dim=0) - if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): @@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): }, ) + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + "gate_proj": ("merged_linear", 0), + "up_proj": ("merged_linear", 1), + } + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4e2e7f5761544..fefa9fd62d1d0 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -12,9 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
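Aside on the Molmo change above: the new `SwiGLU` module applies SiLU to the *second* half of the merged `gate_up` projection, the reverse of the gate-first convention, which is why the old `load_weights` chunk-and-swap for `gate_up_proj` could be removed. A minimal standalone sketch of how the two orderings relate, assuming (as in vLLM's `SiluAndMul` layer) that the gate-first convention activates the first half of the last dimension; the `up`/`gate` names just mirror the removed `up_proj`/`gate_proj` variables:

```python
import torch
import torch.nn.functional as F


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Gate-first convention (assumed to match vLLM's SiluAndMul):
    # the activation is applied to the first half of the last dimension.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


def molmo_swiglu(x: torch.Tensor) -> torch.Tensor:
    # The SwiGLU module added above: activation on the *second* half.
    x, gate = x.chunk(2, dim=-1)
    return x * F.silu(gate)


x = torch.randn(2, 8)
up, gate = x.chunk(2, dim=-1)
swapped = torch.cat([gate, up], dim=-1)

# Swapping the two halves makes the conventions agree, which is what the
# removed chunk-and-concat in load_weights() used to do to the checkpoint
# weights themselves.
assert torch.allclose(molmo_swiglu(x), silu_and_mul(swapped))
```

Keeping the checkpoint ordering and adjusting the activation instead avoids rewriting the merged weight at load time.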
+from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -32,10 +32,14 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, + MultiModalFieldConfig, MultiModalInputsV2, + MultiModalKwargs, NestedTensors, + PlaceholderRange) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement) + ProcessorInputs, PromptReplacement, + _BoundPromptReplacement, + _PlaceholderInfo) from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -306,11 +310,11 @@ def get_max_phi3v_image_tokens( *, num_crops: Optional[int] = None, ) -> int: - mm_processor_kwargs = {} + hf_processor_mm_kwargs = {} if num_crops: - mm_processor_kwargs["num_crops"] = num_crops + hf_processor_mm_kwargs["num_crops"] = num_crops - processor = ctx.get_hf_processor(**mm_processor_kwargs) + processor = ctx.get_hf_processor(**hf_processor_mm_kwargs) return processor.calc_num_image_tokens_from_image_size( width=MAX_IMAGE_FEATURE_SIZE_WIDTH, @@ -331,39 +335,50 @@ def _get_hf_processor( def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) + input_ids = processed_outputs["input_ids"] + assert isinstance(input_ids, torch.Tensor) + # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids, # which will cause OverflowError when decoding the prompt_ids. 
# Therefore, we need to do an early replacement here - token_ids = processed_outputs['input_ids'] - token_ids[token_ids < 0] = _IMAGE_TOKEN_ID - processed_outputs['input_ids'] = token_ids + input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID) return processed_outputs + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() image_tokens: list[str] = hf_processor.img_tokens # type: ignore image_processor = hf_processor.image_processor # type: ignore - mm_config = self.ctx.get_mm_config() - max_images = mm_config.limit_per_prompt.get("image", 1) + tokenizer = self._get_tokenizer() + bos_token_id = tokenizer.bos_token_id + assert isinstance(bos_token_id, int) def get_replacement_phi3v(item_idx: int): image_size = mm_items.get_image_size(item_idx) @@ -372,21 +387,44 @@ def get_replacement_phi3v(item_idx: int): height=image_size.height, ) - return [_IMAGE_TOKEN_ID] * num_tokens + return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id] return [ PromptReplacement( modality="image", target=image_token, replacement=get_replacement_phi3v, - ) for image_token in image_tokens[:max_images] + ) for image_token in image_tokens[:len(mm_items.images)] ] + def _apply_prompt_replacements( + self, + token_ids: list[int], + prompt_repls: Sequence[_BoundPromptReplacement], + mm_item_counts: Mapping[str, int], + ) -> tuple[list[int], str, list[_PlaceholderInfo]]: + token_ids, text, placeholders = super()._apply_prompt_replacements( + token_ids=token_ids, + prompt_repls=prompt_repls, + mm_item_counts=mm_item_counts, + ) + + # Keep the behavior in line with HF processor + if text.startswith(" <|image|>"): + text = text.replace(" <|image|>", "<|image|>", 1) + token_ids = [token_ids[0], *token_ids[2:]] + placeholders = [ + _PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement) + for p in placeholders + ] + + return token_ids, text, placeholders + def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - num_images = mm_counts["image"] + num_images = mm_counts.get("image", 0) data = dummy_image_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, @@ -401,9 +439,28 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text="".join(image_tokens[:num_images]), mm_data=data, - mm_processor_kwargs={}, ) + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) + + # Only <|image|> tokens should be considered as placeholders, + # so we ignore the trailing bos_token_id + result["mm_placeholders"] = { + modality: [ + PlaceholderRange(offset=p["offset"], length=p["length"] - 1) + for p in ps + ] + for modality, ps in result["mm_placeholders"].items() + } + + return result + @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor) diff --git a/vllm/model_executor/models/qwen.py 
b/vllm/model_executor/models/qwen.py index 63d1374ab4092..baf955f6b515d 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -225,7 +225,7 @@ def __init__( d_model: int, n_head: int, mlp_ratio: float = 4.0, - norm_layer: Callable = nn.LayerNorm, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -266,7 +266,7 @@ def __init__( layers: int, heads: int, mlp_ratio: float = 4.0, - norm_layer: Callable = nn.LayerNorm, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 6259166a7fc57..25a351bd9c656 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -26,7 +26,7 @@ import numpy as np import torch import torch.nn as nn -from transformers import BatchFeature, ProcessorMixin +from transformers import BatchFeature from transformers.models.qwen2_audio import (Qwen2AudioConfig, Qwen2AudioEncoder, Qwen2AudioProcessor) @@ -38,10 +38,10 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement) + ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -73,7 +73,7 @@ def forward(self, audio_features): # From Qwen2AudioEncoder._get_feat_extract_output_lengths -def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor): +def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): feat_lengths = (input_lengths - 1) // 2 + 1 output_lengths = (feat_lengths - 2) // 2 + 1 return feat_lengths, output_lengths @@ -88,13 +88,18 @@ def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int: class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): - def _get_hf_processor(self) -> Qwen2AudioProcessor: + def _get_hf_processor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> Qwen2AudioProcessor: return self.ctx.get_hf_processor(Qwen2AudioProcessor) def _get_feature_extractor(self) -> WhisperFeatureExtractor: return self._get_hf_processor().feature_extractor # type: ignore - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -102,50 +107,61 @@ def _get_processor_data( feature_extractor = self._get_feature_extractor() mm_items.resample_audios(feature_extractor.sampling_rate) - return super()._get_processor_data(mm_items) + return super()._get_hf_mm_data(mm_items) def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: - processor_data = dict(processor_data) - audios = processor_data.pop("audios", []) + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) if audios: - processor_data["audios"] = audios + mm_data["audios"] = audios 
feature_extractor = self._get_feature_extractor() - mm_processor_kwargs = dict( - **mm_processor_kwargs, + mm_kwargs = dict( + **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, ) else: # NOTE: WhisperFeatureExtractor cannot handle empty list of audios pass - return super()._call_hf_processor( - hf_processor, + processed_outputs = super()._call_hf_processor( prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), ) def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) placeholder = hf_config.audio_token_index - feature_attention_mask = hf_inputs.get("feature_attention_mask") + feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") if feature_attention_mask is None: audio_output_lengths = [] else: + assert isinstance(feature_attention_mask, torch.Tensor) _, audio_output_lengths = _get_feat_extract_output_lengths( feature_attention_mask.sum(-1)) @@ -168,14 +184,13 @@ def _get_dummy_mm_inputs( sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate - audio_count = mm_counts["audio"] + audio_count = mm_counts.get("audio", 0) audio = np.zeros(audio_len) data = {"audio": [audio] * audio_count} return ProcessorInputs( prompt_text="<|AUDIO|>" * audio_count, mm_data=data, - mm_processor_kwargs={}, ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index fb97eb1916002..574845ef5a525 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,9 +22,10 @@ # limitations under the License. 
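Aside on the Qwen2-Audio change above: `_get_prompt_replacements` now reads `feature_attention_mask` from `out_mm_kwargs` and turns per-clip frame counts into placeholder token counts via `_get_feat_extract_output_lengths`. A small standalone sketch of that arithmetic; the helper is copied from the diff, while the 3000-frame mask is only a made-up example (roughly what Whisper's feature extractor yields for a 30-second chunk):

```python
import torch


def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    # Same arithmetic as the helper in qwen2_audio.py: two stride-2
    # convolutions in the audio encoder shrink the frame count twice.
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths


# Hypothetical mask for two clips padded to the same number of mel frames.
feature_attention_mask = torch.tensor(
    [[1] * 3000,
     [1] * 1500 + [0] * 1500])

_, audio_output_lengths = _get_feat_extract_output_lengths(
    feature_attention_mask.sum(-1))

# One replacement length per audio item: [750, 375] here.
print(audio_output_lengths.tolist())
```

Each per-item count then determines how many audio placeholder tokens the corresponding `PromptReplacement` expands to.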
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import cached_property, partial -from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, Type, TypedDict, Union) +from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, + Set, Tuple, Type, TypedDict, Union) +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -54,10 +55,11 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, + MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement) + ProcessorInputs, PromptReplacement) from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -229,9 +231,9 @@ class Qwen2VisionAttention(nn.Module): def __init__( self, - embed_dim: Optional[int] = None, - num_heads: Optional[int] = None, - projection_size: Optional[int] = None, + embed_dim: int, + num_heads: int, + projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -264,7 +266,7 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, + rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -347,7 +349,7 @@ def __init__( num_heads: int, mlp_ratio: float, act_layer: Type[nn.Module] = QuickGELU, - norm_layer: Type[nn.Module] = None, + norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -384,7 +386,7 @@ def __init__( self, patch_size: int = 14, temporal_patch_size: int = 2, - in_chans: int = 3, + in_channels: int = 3, embed_dim: int = 1152, ) -> None: super().__init__() @@ -392,8 +394,8 @@ def __init__( self.temporal_patch_size = temporal_patch_size self.embed_dim = embed_dim - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_chans, + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, @@ -413,7 +415,7 @@ def __init__( self, d_model: int, context_dim: int, - norm_layer: Type[nn.Module] = None, + norm_layer: Optional[Callable[[int], nn.Module]] = None, spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -489,15 +491,15 @@ def __init__( ) -> None: super().__init__() - patch_size: int = vision_config.patch_size - temporal_patch_size: int = vision_config.temporal_patch_size - spatial_merge_size: int = vision_config.spatial_merge_size - in_chans: int = vision_config.in_chans - hidden_size: int = vision_config.hidden_size - embed_dim: int = vision_config.embed_dim - depth: int = vision_config.depth - num_heads: int = vision_config.num_heads - mlp_ratio: float = vision_config.mlp_ratio + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = 
vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio self.spatial_merge_size = spatial_merge_size self.num_heads = num_heads @@ -506,7 +508,7 @@ def __init__( self.patch_embed = Qwen2VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, - in_chans=in_chans, + in_channels=in_channels, embed_dim=embed_dim, ) @@ -733,8 +735,12 @@ def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": if k == "video": # Special case since even a single item can be a list multi_data[k] = ( # type: ignore[index] - v if (isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment] - or is_list_of(v, list)) else [v] + v if ( + isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment] + or is_list_of(v, list) + or isinstance(v[0], (np.ndarray, torch.Tensor)) + and v[0].ndim == 4 + ) else [v] ) elif k in ("image", "audio"): multi_data[k] = ( # type: ignore[index] @@ -754,6 +760,12 @@ def get_item_counts(self) -> Mapping[str, int]: for m, items in self.items() } + def has_embedding_inputs(self) -> bool: + return any( + isinstance(items, dict) or any( + isinstance(item, torch.Tensor) for item in items) + for items in self.values()) + class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): @@ -784,7 +796,7 @@ def _get_hf_processor( return hf_processor - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -805,7 +817,7 @@ def _get_processor_data( and v[0].ndim == 2): # Pass through embedding inputs (multi) passthrough_data[f"{k}_embeds"] = v - else: + elif len(v) > 0: # Map keys to plural form, e.g.: image -> images processor_data[f"{k}s"] = v else: @@ -816,8 +828,8 @@ def _get_processor_data( def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() image_processor = _get_image_processor(hf_processor) @@ -831,7 +843,9 @@ def _get_prompt_replacements( merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): - grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx] + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + num_tokens = grid_thw.prod() // merge_length return placeholder[modality] * num_tokens @@ -844,11 +858,40 @@ def get_replacement_qwen2vl(item_idx: int, modality: str): ) for modality in ("image", "video") ] + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_slice_idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist() + image_slices = [ + slice(image_slice_idxs[i], image_slice_idxs[i + 1]) + for i in range(len(image_grid_thw)) + ] + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_slice_idxs = [0] + video_grid_thw.prod(-1).cumsum_(0).tolist() + video_slices = [ + slice(video_slice_idxs[i], video_slice_idxs[i + 1]) + for i in range(len(video_grid_thw)) + ] + + return dict( + pixel_values=MultiModalFieldConfig.flat("image", image_slices), + image_embeds=MultiModalFieldConfig.flat("image", image_slices), + 
image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat( + "video", video_slices), + video_embeds=MultiModalFieldConfig.flat("video", video_slices), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - num_images = mm_counts["image"] + num_images = mm_counts.get("image", 0) hf_processor = self._get_hf_processor() image_token: str = hf_processor.image_token image_processor = _get_image_processor(hf_processor) @@ -869,7 +912,6 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text=image_token * num_images, mm_data=data, - mm_processor_kwargs={}, ) @@ -950,9 +992,7 @@ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): return None return quant_config - def _validate_and_reshape_mm_tensor(self, - mm_input: Union[torch.Tensor, - List[torch.Tensor]], + def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. " @@ -962,7 +1002,8 @@ def _validate_and_reshape_mm_tensor(self, return mm_input if mm_input.ndim != 3: raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim}") + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") return torch.concat(list(mm_input)) else: return torch.concat(mm_input) diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py index 28c37bb96612c..02ca7fe08e556 100644 --- a/vllm/model_executor/models/telechat2.py +++ b/vllm/model_executor/models/telechat2.py @@ -31,19 +31,6 @@ class TeleChat2Model(LlamaModel): - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "transformer.": "model.", - }, - orig_to_new_substr={ - ".h.": ".layers.", - ".self_attention.": ".self_attn.", - ".word_embeddings.": ".embed_tokens.", - ".dense.": ".o_proj.", - ".ln_f.": ".norm.", - }, - ) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # 1. 
Initialize the LlamaModel with bias vllm_config.model_config.hf_config.bias = True @@ -118,6 +105,19 @@ def load_weights(self, weights: Iterable[Tuple[str, class TeleChat2ForCausalLM(LlamaForCausalLM): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "transformer.": "model.", + }, + orig_to_new_substr={ + ".h.": ".layers.", + ".self_attention.": ".self_attn.", + ".word_embeddings.": ".embed_tokens.", + ".dense.": ".o_proj.", + ".ln_f.": ".norm.", + }, + ) + def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): return TeleChat2Model(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 509ad9e580ddf..7b4aeeec5f403 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -23,10 +23,11 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement) + ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.utils import is_list_of @@ -72,11 +73,19 @@ def get_ultravox_max_audio_tokens(ctx: InputContext): class UltravoxMultiModalProcessor(BaseMultiModalProcessor): + def _get_hf_processor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> ProcessorMixin: + return self.ctx.get_hf_processor() + def _get_feature_extractor(self) -> WhisperFeatureExtractor: hf_processor = self._get_hf_processor() return hf_processor.audio_processor.feature_extractor # type: ignore - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -84,33 +93,41 @@ def _get_processor_data( feature_extractor = self._get_feature_extractor() mm_items.resample_audios(feature_extractor.sampling_rate) - return super()._get_processor_data(mm_items) + return super()._get_hf_mm_data(mm_items) def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: - processor_data = dict(processor_data) - audios = processor_data.pop("audios", []) + # Text-only input not supported in composite processor + if not mm_data: + tokenizer = self._get_tokenizer() + + prompt_ids = tokenizer.encode( + prompt, + add_special_tokens=False, # type: ignore + ) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) if not audios: return super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) feature_extractor = self._get_feature_extractor() - mm_processor_kwargs = dict( - **mm_processor_kwargs, + mm_kwargs = dict( + **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, ) - # Already resampled by _get_processor_data + 
# Already resampled by _get_hf_mm_data assert is_list_of(audios, np.ndarray) # Ultravox processor doesn't support multiple inputs, @@ -119,13 +136,12 @@ def _call_hf_processor( shared_outputs = {} for audio in audios: # NOTE: Ultravox processor accepts "audio" instead of "audios" - item_processor_data = dict(**processor_data, audio=audio) + item_processor_data = dict(**mm_data, audio=audio) item_outputs = super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=item_processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=item_processor_data, + mm_kwargs=mm_kwargs, ) audio_features.append(item_outputs.pop("audio_values")[0]) @@ -139,17 +155,28 @@ def _call_hf_processor( ) return BatchFeature(combined_outputs) + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + audio_features=MultiModalFieldConfig.batched("audio"), + audio_token_len=MultiModalFieldConfig.batched("audio"), + audio_embeds=MultiModalFieldConfig.batched("audio"), + ) + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() placeholder = hf_processor.audio_token_replacement # type: ignore def get_replacement_ultravox(item_idx: int): - audio_token_len = hf_inputs["audio_token_len"][item_idx] + audio_token_len = out_mm_kwargs["audio_token_len"][item_idx] return placeholder * audio_token_len return [ @@ -168,14 +195,13 @@ def _get_dummy_mm_inputs( sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate - audio_count = mm_counts["audio"] + audio_count = mm_counts.get("audio", 0) audio = np.zeros(audio_len) data = {"audio": [audio] * audio_count} return ProcessorInputs( prompt_text="<|audio|>" * audio_count, mm_data=data, - mm_processor_kwargs={}, ) diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index ed3bb82bf0aaa..3e09ef1fcbb56 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,10 +1,14 @@ +import base64 +from io import BytesIO +from pathlib import Path + import numpy as np import numpy.typing as npt from vllm.inputs.registry import InputContext from vllm.utils import PlaceholderModule -from .base import MultiModalPlugin +from .base import MediaIO, MultiModalPlugin from .inputs import AudioItem, MultiModalData, MultiModalKwargs try: @@ -12,6 +16,11 @@ except ImportError: librosa = PlaceholderModule("librosa") # type: ignore[assignment] +try: + import soundfile +except ImportError: + soundfile = PlaceholderModule("soundfile") # type: ignore[assignment] + class AudioPlugin(MultiModalPlugin): """Plugin for audio data.""" @@ -39,3 +48,28 @@ def resample_audio( target_sr: float, ) -> npt.NDArray[np.floating]: return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) + + +class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): + + def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: + return librosa.load(BytesIO(data), sr=None) + + def load_base64( + self, + media_type: str, + data: str, + ) -> tuple[npt.NDArray, float]: + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]: + return librosa.load(filepath, sr=None) + + def encode_base64(self, media: tuple[npt.NDArray, float]) -> 
str: + audio, sr = media + + with BytesIO() as buffer: + soundfile.write(buffer, audio, sr, format="WAV") + data = buffer.getvalue() + + return base64.b64encode(data).decode('utf-8') diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 1e5a46946c6c0..cdda6f8052794 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, +from pathlib import Path +from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple, Optional, Sequence, Tuple, Type, TypeVar, Union) from torch import nn @@ -118,7 +119,7 @@ def map_input( self, model_config: "ModelConfig", data: MultiModalData[Any], - mm_processor_kwargs: Optional[Dict[str, Any]], + mm_processor_kwargs: Optional[dict[str, Any]], ) -> MultiModalKwargs: """ Transform the data into a dictionary of model inputs using the @@ -254,10 +255,10 @@ class MultiModalPlaceholderMap: """ class IndexMap(NamedTuple): - src: List[int] - dest: List[int] + src: list[int] + dest: list[int] - src_ranges: List[range] + src_ranges: list[range] """ The indices of the multi-modal embeddings that will replace the corresponding placeholder embeddings pointed to by ``dest_ranges``. @@ -268,7 +269,7 @@ class IndexMap(NamedTuple): The total number of flattened multi-modal embeddings. """ - dest_ranges: List[range] + dest_ranges: list[range] """ The indices of the placeholder embeddings that will be replaced by the multimodal embeddings. @@ -288,7 +289,7 @@ def __init__(self): @classmethod def from_seq_group( cls, seq_group: "SequenceGroupMetadata", positions: range - ) -> Tuple[Optional[MultiModalDataDict], Dict[str, + ) -> Tuple[Optional[MultiModalDataDict], dict[str, "MultiModalPlaceholderMap"]]: """ Returns the multi-modal items that intersect with the portion of a @@ -296,35 +297,37 @@ def from_seq_group( ``MultiModalPlaceholderMap`` that relates the multi-modal embedding vectors to their corresponding placeholders. - Consider the following scenarios: + Examples: - Prompt: |AAAA BBBB What's in these images?| - Positions: |.................................| + .. code-block:: - images = [A, B] - src_ranges = [(0, 4), (4, 8)] - dest_ranges = [(0, 4), (5, 9)] + Prompt: |AAAA BBBB What's in these images?| + Positions: |.................................| - Prompt: |AAAA BBBB What's in these images?| - Positions: | ..... | + images = [A, B] + src_ranges = [(0, 4), (4, 8)] + dest_ranges = [(0, 4), (5, 9)] - images = [A, B] - src_ranges = [(2, 4), (4, 6)] - dest_ranges = [(0, 2), (3, 5)] + Prompt: |AAAA BBBB What's in these images?| + Positions: | ..... | - Prompt: |AAAA BBBB What's in these images?| - Positions: | ......... | + images = [A, B] + src_ranges = [(2, 4), (4, 6)] + dest_ranges = [(0, 2), (3, 5)] - images = [B] - src_ranges = [(0, 4)] - dest_ranges = [(0, 4)] + Prompt: |AAAA BBBB What's in these images?| + Positions: | ......... 
| - Prompt: |AAAA BBBB What's in these images?| - Positions: | .......................| + images = [B] + src_ranges = [(0, 4)] + dest_ranges = [(0, 4)] - images = [] - src_ranges = [] - dest_ranges = [] + Prompt: |AAAA BBBB What's in these images?| + Positions: | .......................| + + images = [] + src_ranges = [] + dest_ranges = [] """ seq_mm_data = seq_group.multi_modal_data seq_mm_placeholders = seq_group.multi_modal_placeholders @@ -376,9 +379,9 @@ def from_seq_group( def append_items_from_seq_group( self, positions: range, - multi_modal_items: List[_T], + multi_modal_items: list[_T], multi_modal_placeholders: Sequence[PlaceholderRange], - ) -> List[_T]: + ) -> list[_T]: """ Adds the multi-modal items that intersect ```positions`` to this placeholder map and returns the intersecting items. @@ -454,3 +457,22 @@ def index_map(self) -> "IndexMap": return MultiModalPlaceholderMap.IndexMap(src=src_indices, dest=dest_indices) + + +class MediaIO(ABC, Generic[_T]): + + @abstractmethod + def load_bytes(self, data: bytes) -> _T: + raise NotImplementedError + + @abstractmethod + def load_base64(self, media_type: str, data: str) -> _T: + """ + List of media types: + https://www.iana.org/assignments/media-types/media-types.xhtml + """ + raise NotImplementedError + + @abstractmethod + def load_file(self, filepath: Path) -> _T: + raise NotImplementedError diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index c705e1a3d1554..14c79dfadec0c 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,4 +1,7 @@ +import base64 from functools import lru_cache +from io import BytesIO +from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional import torch @@ -9,7 +12,7 @@ from vllm.transformers_utils.processor import get_image_processor from vllm.utils import is_list_of -from .base import MultiModalPlugin +from .base import MediaIO, MultiModalPlugin from .inputs import ImageItem, MultiModalData, MultiModalKwargs if TYPE_CHECKING: @@ -96,3 +99,39 @@ def rescale_image_size(image: Image.Image, if transpose >= 0: image = image.transpose(Image.Transpose(transpose)) return image + + +class ImageMediaIO(MediaIO[Image.Image]): + + def __init__(self, *, image_mode: str = "RGB") -> None: + super().__init__() + + self.image_mode = image_mode + + def load_bytes(self, data: bytes) -> Image.Image: + image = Image.open(BytesIO(data)) + image.load() + return image.convert(self.image_mode) + + def load_base64(self, media_type: str, data: str) -> Image.Image: + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> Image.Image: + image = Image.open(filepath) + image.load() + return image.convert(self.image_mode) + + def encode_base64( + self, + media: Image.Image, + *, + image_format: str = "JPEG", + ) -> str: + image = media + + with BytesIO() as buffer: + image = image.convert(self.image_mode) + image.save(buffer, image_format) + data = buffer.getvalue() + + return base64.b64encode(data).decode('utf-8') diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 9ecae2c1ca2bf..1fbda6e0b8750 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -1,12 +1,16 @@ +from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple, - TypedDict, TypeVar, Union, cast, final) +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from typing import (Any, Literal, NamedTuple, TypedDict, TypeVar, 
Union, cast, + final) import numpy as np import torch import torch.types from PIL.Image import Image -from typing_extensions import NotRequired, TypeAlias +from transformers import BatchFeature +from typing_extensions import NotRequired, TypeAlias, assert_never from vllm.utils import JSONTree, is_list_of, json_map_leaves @@ -44,7 +48,7 @@ """ # yapf: enable -MultiModalData: TypeAlias = Union[_T, List[_T]] +MultiModalData: TypeAlias = Union[_T, list[_T]] """ Either a single data item, or a list of data items. @@ -79,13 +83,135 @@ class MultiModalDataBuiltins(TypedDict, total=False): """ +class ImageSize(NamedTuple): + width: int + height: int + + +class MultiModalDataItems(UserDict[str, list[Any]]): + """ + As :class:`MultiModalDataDict`, but normalized such that each entry + corresponds to a list. + """ + + @staticmethod + def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": + """ + Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`. + """ + multi_data = MultiModalDataItems() + + for k, v in data.items(): + # TODO: Make a separate modality for embedding inputs + # to avoid confusion + # yapf: disable + if k == "video": + # Special case since even a single item can be a list + multi_data[k] = ( # type: ignore[index] + v if ( + isinstance(v, torch.Tensor) + or is_list_of(v, list) + or isinstance(v[0], (np.ndarray, torch.Tensor)) + and v[0].ndim == 4 + ) else [v] + ) + elif k in ("image", "audio"): + multi_data[k] = ( # type: ignore[index] + v if isinstance(v, (torch.Tensor, list)) else [v] + ) + else: + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + # yapf: enable + + return multi_data + + # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to + # `self.images` doesn't update this dictionary, which may be confusing + # We annotate the getter methods as `Sequence` to prevent others from + # trying to update the list in this way + @property + def images(self) -> Sequence[ImageItem]: + return self.get("image", []) + + @property + def videos(self) -> Sequence[VideoItem]: + return self.get("video", []) + + @property + def audios(self) -> Sequence[AudioItem]: + return self.get("audio", []) + + def get_item_counts(self) -> Mapping[str, int]: + return {m: len(items) for m, items in self.items()} + + def has_embedding_inputs(self) -> bool: + return any( + any(isinstance(item, torch.Tensor) for item in items) + for items in self.values()) + + def get_image_size(self, item_idx: int) -> ImageSize: + image = self.images[item_idx] + + if isinstance(image, Image): + return ImageSize(*image.size) + if isinstance(image, (np.ndarray, torch.Tensor)): + _, h, w = image.shape + return ImageSize(w, h) + + assert_never(image) + + def get_audio_with_sr( + self, + item_idx: int, + *, + default_sr: float, + ) -> tuple[np.ndarray, float]: + audio = self.audios[item_idx] + + if isinstance(audio, tuple): + return audio + if isinstance(audio, list): + return np.array(audio), default_sr + if isinstance(audio, np.ndarray): + return audio, default_sr + + assert_never(audio) + + def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None: + """ + If :code:`drop_sr=True`, the audio items in this dictionary are updated + to be NumPy arrays which implicitly means that their sampling rate is + the same as the model's expected sampling rate; otherwise, they remain + as :code:`(audio, new_sr)` tuples. 
+ """ + # Avoid circular import + from .audio import resample_audio + + if not self.audios: + return + + new_audios = [] + for item_idx in range(len(self.audios)): + audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr) + audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr) + + new_audios.append(audio if drop_sr else (audio, new_sr)) + + self["audio"] = new_audios + + class PlaceholderRange(TypedDict): """ Placeholder location information for multi-modal data. - For example: - Prompt: AAAA BBBB What is in these images? + Example: + + Prompt: :code:`AAAA BBBB What is in these images?` + Images A and B will have: + + .. code-block:: + A: { "offset": 0, "length": 4 } B: { "offset": 5, "length": 4 } """ @@ -97,25 +223,256 @@ class PlaceholderRange(TypedDict): """The length of the placeholder.""" -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor, - Tuple[torch.Tensor, ...]] +NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, + tuple[torch.Tensor, ...]] """ Uses a list instead of a tensor if the dimensions of each element do not match. """ -BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] + +def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: + """Equality check between :data:`NestedTensors` objects.""" + if isinstance(a, torch.Tensor): + return isinstance(b, torch.Tensor) and bool((a == b).all().item()) + elif isinstance(b, torch.Tensor): + return isinstance(a, torch.Tensor) and bool((b == a).all().item()) + + if isinstance(a, list): + return (isinstance(b, list) + and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))) + if isinstance(b, list): + return (isinstance(a, list) + and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))) + + # Both a and b are scalars + return a == b + + +BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via :meth:`MultiModalKwargs.batch`. """ +@dataclass(frozen=True) +class MultiModalFieldItem: + """ + Contains metadata and data in :class:`MultiModalKwargs` + corresponding to a data item in :class:`MultiModalDataItems`. + """ + field: "BaseMultiModalField" + data: NestedTensors + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False + + return (self.field == other.field + and nested_tensors_equal(self.data, other.data)) + + +@dataclass(frozen=True) +class BaseMultiModalField(ABC): + """Abstract base class for a field in :class:`MultiModalKwargs`.""" + key: str + modality: str + + @abstractmethod + def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + raise NotImplementedError + + def _build_item(self, data: NestedTensors) -> MultiModalFieldItem: + return MultiModalFieldItem(self, data) + + def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem: + """Merge multiple instances of :class:`MultiModalFieldItem` together.""" + fields = [item.field for item in batch] + if len(set(fields)) > 1: + raise ValueError(f"Cannot merge different {fields=}") + + data = self._reduce_data([item.data for item in batch]) + + return self._build_item(data) + + +@dataclass(frozen=True) +class MultiModalBatchedField(BaseMultiModalField): + """ + A :class:`BaseMultiModalField` implementation where an item is obtained by + directly indexing into the first dimension of the underlying data. 
+ """ + + def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]: + return [self._build_item(item) for item in batch] + + def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): + first_shape = batch[0].shape + if all(item.shape == first_shape for item in batch): + return torch.stack(batch) + + return batch + + +@dataclass(frozen=True) +class MultiModalFlatField(BaseMultiModalField): + """ + A :class:`BaseMultiModalField` implementation where an item is obtained by + slicing along the first dimension of the underlying data. + """ + + def build_items( + self, + batch: NestedTensors, + slices: Sequence[slice], + ) -> list[MultiModalFieldItem]: + return [self._build_item(batch[slice_]) for slice_ in slices] + + def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): + first_shape = batch[0].shape + if all(item.shape[1:] == first_shape[1:] for item in batch): + return torch.concat(batch) + + return [elem for item in batch for elem in item] + + +class MultiModalFieldConfig: + + @staticmethod + def batched(modality: str): + return MultiModalFieldConfig( + field_cls=MultiModalBatchedField, + modality=modality, + ) + + @staticmethod + def flat(modality: str, slices: Sequence[slice]): + return MultiModalFieldConfig( + field_cls=MultiModalFlatField, + modality=modality, + slices=slices, + ) + + def __init__( + self, + field_cls: type[BaseMultiModalField], + modality: str, + **field_config: Any, + ) -> None: + super().__init__() + + self._field_cls = field_cls + self._modality = modality + self._field_config = field_config + + def build_items( + self, + key: str, + batch: NestedTensors, + ) -> list[MultiModalFieldItem]: + field = self._field_cls(key=key, modality=self._modality) + return field.build_items(batch, **self._field_config) # type: ignore + + class MultiModalKwargs(UserDict[str, NestedTensors]): """ A dictionary that represents the keyword arguments to :meth:`~torch.nn.Module.forward`. + + The metadata :code:`items_by_key` defines how to split batched keyword + arguments corresponding to each data item in :class:`MultiModalDataItems`: + + - For a keyword argument, we can access the :code:`i` th item in the batch + via :code:`items_by_key[key][i]`. + - We can gather the keyword arguments belonging to a modality by finding + the keys with items that belong to that modality, then accessing + the :code:`i` th item in the batch for each such key. + + Example: + + .. code-block:: python + + # All items belong to the "image" modality + items_by_key={ + "pixel_values": [a, b, c, d], # "image" modality + "image_grid_thw": [e, f, g, h], # "image" modality + "pixel_values_video": [h, i, j], # "video" modality + "video_grid_thw": [k, l, m], # "video" modality + } + + - The keyword arguments belonging to the first image are + :code:`{"pixel_values": a, "image_grid_thw": e}`. + - The keyword arguments belonging to the second video are + :code:`{"pixel_values_video": i, "video_grid_thw": l}`. 
""" + @staticmethod + def from_hf_inputs( + hf_inputs: BatchFeature, + config_by_key: Mapping[str, MultiModalFieldConfig], + *, + enable_sanity_checks: bool = False, + ): + # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key` + # We assume that those fields are not used in vLLM + items_by_key = { + key: config.build_items(key, batch) + for key, config in config_by_key.items() + if (batch := hf_inputs.get(key)) is not None + } + + return MultiModalKwargs.from_items_by_key( + items_by_key, + enable_sanity_checks=enable_sanity_checks, + ) + + @staticmethod + def from_items_by_key( + items_by_key: Mapping[str, list[MultiModalFieldItem]], + *, + enable_sanity_checks: bool = False, + ) -> "MultiModalKwargs": + data = { + key: items[0].field.reduce(items).data + for key, items in items_by_key.items() + } + + return MultiModalKwargs(data, + items_by_key=items_by_key, + enable_sanity_checks=enable_sanity_checks) + + def __init__( + self, + data: Mapping[str, NestedTensors], + *, + items_by_key: Mapping[str, list[MultiModalFieldItem]] = {}, + enable_sanity_checks: bool = False, + ) -> None: + super().__init__(data) + + # Shallow copy to avoid footgun in case a defaultdict is passed in + self._items_by_key = dict(items_by_key) + + keys_by_modality = defaultdict[str, set[str]](set) + for key, items in items_by_key.items(): + for item in items: + keys_by_modality[item.field.modality].add(key) + + self._keys_by_modality = dict(keys_by_modality) + + if enable_sanity_checks: + for modality, keys in keys_by_modality.items(): + items_in_modality = {k: items_by_key[k] for k in keys} + batch_sizes = {k: len(v) for k, v in items_in_modality.items()} + batch_size = next(iter(batch_sizes.values()), 0) + assert all(bs == batch_size + for bs in batch_sizes.values()), dict( + modality=modality, + batch_sizes=batch_sizes, + items_by_key=items_by_key) + @staticmethod def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: """ @@ -139,7 +496,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: # Only tensors (not lists) can be stacked. return stacked - tensors_ = cast(List[torch.Tensor], stacked) + tensors_ = cast(list[torch.Tensor], stacked) if any(t.shape != tensors_[0].shape for t in tensors_): # The tensors have incompatible shapes and can't be stacked. return tensors_ @@ -147,7 +504,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: return torch.stack(tensors_) @staticmethod - def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: + def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs: """ Batch multiple inputs together into a dictionary. @@ -162,7 +519,7 @@ def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: # We need to consider the case where each item in the batch # contains different modalities (i.e. different keys). 
- item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) + item_lists = defaultdict[str, list[NestedTensors]](list) for inputs in inputs_list: for k, v in inputs.items(): @@ -188,6 +545,57 @@ def as_kwargs( return cast(BatchedTensorInputs, json_mapped) + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False + if self._items_by_key != other._items_by_key: + return False + + ks = self.keys() + return (ks == other.keys() + and all(nested_tensors_equal(self[k], other[k]) for k in ks)) + + def get_item(self, key: str, item_index: int) -> MultiModalFieldItem: + return self._items_by_key[key][item_index] + + def get_items_by_modality( + self, + modality: str, + item_index: int, + ) -> Mapping[str, MultiModalFieldItem]: + """ + Get the keyword arguments corresponding to an item identified by + its modality and index. + """ + keys_to_gather = self._keys_by_modality[modality] + + return { + key: self.get_item(key, item_index) + for key in keys_to_gather if key in self + } + + @staticmethod + def from_items_by_modality( + items_by_modality: Mapping[str, list[Mapping[str, + MultiModalFieldItem]]], + *, + enable_sanity_checks: bool = False, + ) -> "MultiModalKwargs": + """ + Construct a new :class:`MultiModalKwargs` from multiple items returned + by :meth:`get_fields_by_modality`. + """ + items_by_key = defaultdict[str, list[MultiModalFieldItem]](list) + for fields in items_by_modality.values(): + for field in fields: + for k, v in field.items(): + items_by_key[k].append(v) + + return MultiModalKwargs.from_items_by_key( + items_by_key, + enable_sanity_checks=enable_sanity_checks, + ) + MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]] """ @@ -207,16 +615,16 @@ class MultiModalInputsV2(TypedDict): prompt: str """The processed prompt text.""" - prompt_token_ids: List[int] + prompt_token_ids: list[int] """The processed token IDs which includes placeholder tokens.""" - token_type_ids: NotRequired[List[int]] + token_type_ids: NotRequired[list[int]] """The token type IDs of the prompt.""" mm_kwargs: MultiModalKwargs """Keyword arguments to be directly passed to the model after batching.""" - mm_hashes: NotRequired[List[str]] + mm_hashes: NotRequired[list[str]] """The hashes of the multi-modal data.""" mm_placeholders: MultiModalPlaceholderDict diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 6baf19d675d50..3ece0762e3228 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,6 +1,6 @@ +import pickle import re from abc import ABC, abstractmethod -from collections import UserDict from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass, field from functools import lru_cache @@ -8,19 +8,18 @@ import numpy as np import torch +from blake3 import blake3 from PIL.Image import Image from transformers import BatchFeature, ProcessorMixin -from typing_extensions import assert_never from vllm.inputs import DummyData, InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import flatten_2d_lists, full_groupby, is_list_of +from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of -from .audio import resample_audio -from .inputs import (AudioItem, ImageItem, MultiModalDataDict, - MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, - VideoItem) +from .inputs import (MultiModalDataDict, MultiModalDataItems, + 
MultiModalFieldConfig, MultiModalFieldItem, + MultiModalInputsV2, MultiModalKwargs, PlaceholderRange) logger = init_logger(__name__) @@ -201,111 +200,6 @@ def get_replacement(self, item_idx: int) -> _BoundPromptSequence: return bound_replacement -class ImageSize(NamedTuple): - width: int - height: int - - -class MultiModalDataItems(UserDict[str, list[Any]]): - """ - As :class:`MultiModalDataDict`, but normalized such that each entry - corresponds to a list. - """ - - @staticmethod - def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": - """ - Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`. - """ - multi_data = MultiModalDataItems() - - for k, v in data.items(): - # TODO: Make a separate modality for embedding inputs - # to avoid confusion - # yapf: disable - if k == "video": - # Special case since even a single item can be a list - multi_data[k] = ( # type: ignore[index] - v if (isinstance(v, torch.Tensor) - or is_list_of(v, list)) else [v] - ) - elif k in ("image", "audio"): - multi_data[k] = ( # type: ignore[index] - v if isinstance(v, (torch.Tensor, list)) else [v] - ) - else: - multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] - # yapf: enable - - return multi_data - - # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to - # `self.images` doesn't update this dictionary, which may be confusing - # We annotate the getter methods as `Sequence` to prevent others from - # trying to update the list in this way - @property - def images(self) -> Sequence[ImageItem]: - return self.get("image", []) - - @property - def videos(self) -> Sequence[VideoItem]: - return self.get("video", []) - - @property - def audios(self) -> Sequence[AudioItem]: - return self.get("audio", []) - - def get_item_counts(self) -> Mapping[str, int]: - return {m: len(items) for m, items in self.items()} - - def get_image_size(self, item_idx: int) -> ImageSize: - image = self.images[item_idx] - - if isinstance(image, Image): - return ImageSize(*image.size) - if isinstance(image, (np.ndarray, torch.Tensor)): - _, h, w = image.shape - return ImageSize(w, h) - - assert_never(image) - - def get_audio_with_sr( - self, - item_idx: int, - *, - default_sr: float, - ) -> tuple[np.ndarray, float]: - audio = self.audios[item_idx] - - if isinstance(audio, tuple): - return audio - if isinstance(audio, list): - return np.array(audio), default_sr - if isinstance(audio, np.ndarray): - return audio, default_sr - - assert_never(audio) - - def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None: - """ - If :code:`drop_sr=True`, the audio items in this dictionary are updated - to be NumPy arrays which implicitly means that their sampling rate is - the same as the model's expected sampling rate; otherwise, they remain - as :code:`(audio, new_sr)` tuples. 
- """ - if not self.audios: - return - - new_audios = [] - for item_idx in range(len(self.audios)): - audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr) - audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr) - - new_audios.append(audio if drop_sr else (audio, new_sr)) - - self["audio"] = new_audios - - class _TokenMatch(NamedTuple): start_idx: int end_idx: int @@ -583,11 +477,124 @@ def iter_placeholders( ) -class ProcessorInputs(NamedTuple): - """Keyword arguments to :meth:`BaseMultiModalProcessor`""" +@dataclass +class ProcessorInputs: + """Keyword arguments to :meth:`BaseMultiModalProcessor`.""" prompt_text: str mm_data: MultiModalDataDict - mm_processor_kwargs: Mapping[str, object] + hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) + + +class ProcessingCache: + + def __init__(self, capacity: int) -> None: + super().__init__() + + # DEBUG: Set to None to disable + self.debug_cache_hit_ratio_steps: Optional[int] = None + + self._cache = LRUCache[str, Mapping[str, + MultiModalFieldItem]](capacity) + + def _maybe_log_cache_stats(self) -> None: + steps = self.debug_cache_hit_ratio_steps + if not steps: + return + + cache_stats = self._cache.stat() + if cache_stats.total % steps == 0: + logger.debug("ProcessingCache: hit_ratio = %.2f", + cache_stats.hit_ratio) + + def _serialize_item(self, obj: object) -> bytes: + # Simple cases + if isinstance(obj, str): + return obj.encode("utf-8") + if isinstance(obj, bytes): + return obj + if isinstance(obj, Image): + return obj.tobytes() + + # Convertible to NumPy arrays + if isinstance(obj, torch.Tensor): + obj = obj.numpy() + if isinstance(obj, (int, float)): + obj = np.array(obj) + if isinstance(obj, np.ndarray): + return obj.tobytes() + + logger.warning( + "No serialization method found for %s. " + "Falling back to pickle.", type(obj)) + + return pickle.dumps(obj) + + def _item_to_bytes( + self, + key: str, + obj: object, + ) -> Iterable[tuple[bytes, bytes]]: + # Recursive cases + if isinstance(obj, (list, tuple)): + for i, elem in enumerate(obj): + yield from self._item_to_bytes(f"{key}.{i}", elem) + elif isinstance(obj, dict): + for k, v in obj.items(): + yield from self._item_to_bytes(f"{key}.{k}", v) + else: + key_bytes = self._serialize_item(key) + value_bytes = self._serialize_item(obj) + yield key_bytes, value_bytes + + def _hash_kwargs(self, **kwargs: object) -> str: + hasher = blake3() + + for k, v in kwargs.items(): + for k_bytes, v_bytes in self._item_to_bytes(k, v): + hasher.update(k_bytes) + hasher.update(v_bytes) + + return hasher.hexdigest() + + def get( + self, + model_id: str, + modality: str, + input_item: object, + input_kwargs: Mapping[str, object], + ) -> Optional[Mapping[str, MultiModalFieldItem]]: + """ + Get a processed multi-modal item from the cache + according to its dependencies, including: + + - The model ID + - The modality of the item + - The original data item passed to the HF processor + - The configuration options of the HF processor + """ + self._maybe_log_cache_stats() + + cache_key = self._hash_kwargs(model_id=model_id, + **{modality: input_item}, + **input_kwargs) + return self._cache.get(cache_key) + + def put( + self, + model_id: str, + modality: str, + input_item: object, + input_kwargs: Mapping[str, object], + output_kwargs: Mapping[str, MultiModalFieldItem], + ) -> None: + """ + Put a processed multi-modal item into the cache + according to its dependencies (see :meth:`get`). 
+ """ + cache_key = self._hash_kwargs(model_id=model_id, + **{modality: input_item}, + **input_kwargs) + self._cache.put(cache_key, output_kwargs) class BaseMultiModalProcessor(ABC): @@ -595,18 +602,24 @@ class BaseMultiModalProcessor(ABC): Abstract base class to process multi-modal inputs to be used in vLLM. """ - def __init__(self, ctx: InputProcessingContext) -> None: + def __init__(self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True) -> None: super().__init__() self.ctx = ctx + self.cache = cache + self.enable_sanity_checks = enable_sanity_checks def __call__( self, prompt: str, mm_data: MultiModalDataDict, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: - return self.apply(prompt, mm_data, mm_processor_kwargs) + return self.apply(prompt, mm_data, hf_processor_mm_kwargs) def _get_hf_processor(self) -> ProcessorMixin: """ @@ -624,12 +637,21 @@ def _get_mm_items( ) -> MultiModalDataItems: return MultiModalDataItems.from_dict(mm_data) + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + """Given the HF-processed data, output the metadata of each field.""" + raise NotImplementedError + @abstractmethod def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: """ Given the original multi-modal items for this modality @@ -651,7 +673,7 @@ def _find_placeholders( return list( iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts)) - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -669,7 +691,7 @@ def _get_processor_data( and v[0].ndim == 2): # Pass through embedding inputs (multi) passthrough_data[f"{k}_embeds"] = v - else: + elif len(v) > 0: # Map keys to plural form, e.g.: image -> images processor_data[f"{k}s"] = v else: @@ -679,39 +701,181 @@ def _get_processor_data( def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + # Not to be confused with `mm_data` in `self.apply`. + # This refers to the data to be passed to HF processor. + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: return self.ctx.call_hf_processor( - hf_processor, - prompt, - processor_data, - mm_processor_kwargs, + self._get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + mm_kwargs, ) def _apply_hf_processor( self, - prompt: str, + prompt_text: str, mm_items: MultiModalDataItems, - mm_processor_kwargs: Mapping[str, object], - ) -> BatchFeature: - # some mm_processor_kwargs may be used in processor initialization - # instead of processor call - hf_processor = self._get_hf_processor(**mm_processor_kwargs) + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs]: + """ + Apply the HF processor on the full prompt text and multi-modal data. 
+ """ + processor_data, passthrough_data = self._get_hf_mm_data(mm_items) + + processed_data = self._call_hf_processor( + prompt=prompt_text, + mm_data=processor_data, + mm_kwargs=hf_processor_mm_kwargs, + ) + processed_data.update(passthrough_data) - processor_data, passthrough_data = self._get_processor_data(mm_items) + prompt_ids, = processed_data.pop("input_ids").tolist() - hf_inputs = self._call_hf_processor( - hf_processor, - prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_kwargs = MultiModalKwargs.from_hf_inputs( + processed_data, + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), + enable_sanity_checks=self.enable_sanity_checks, ) - hf_inputs.update(passthrough_data) - return hf_inputs + return prompt_ids, mm_kwargs + + def _apply_hf_processor_missing( + self, + prompt_text: str, + mm_missing_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ): + """ + Apply the HF processor on the full prompt text, but only on the + multi-modal data that are missing from the cache. + + Note: We pass prompt text and multi-modal data into the HF processor + in separate calls to avoid HF prompt replacement being done for + cached items; instead, we rely on our own prompt replacement logic + for the full text. + """ + mm_missing_counts = mm_missing_data_items.get_item_counts() + + prompt_ids, _ = self._apply_hf_processor( + prompt_text=prompt_text, + mm_items=MultiModalDataItems({}), + hf_processor_mm_kwargs={}, + ) + + # Some HF processors (e.g. Qwen2-VL) expect corresponding + # multi-modal tokens to be in the prompt text + dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts) + + _, mm_missing_kwargs = self._apply_hf_processor( + prompt_text=dummy_inputs.prompt_text, + mm_items=mm_missing_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return prompt_ids, mm_missing_kwargs + + def _cached_apply_hf_processor( + self, + prompt_text: str, + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs]: + """ + Apply the HF processor on the full prompt text, + caching the results and reusing cached results. 
+ """ + cache = self.cache + model_id = self.ctx.model_config.model + + if cache is None or mm_data_items.has_embedding_inputs(): + return self._apply_hf_processor( + prompt_text=prompt_text, + mm_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + mm_maybe_cached_field_items = { + modality: [ + cache.get(model_id, modality, item, hf_processor_mm_kwargs) + for item in items + ] + for modality, items in mm_data_items.items() + } + + mm_missing_idxs = { + modality: [idx for idx, out in enumerate(fields) if out is None] + for modality, fields in mm_maybe_cached_field_items.items() + } + mm_missing_data = { + modality: [mm_data_items[modality][idx] for idx in idxs] + for modality, idxs in mm_missing_idxs.items() + } + mm_missing_data_items = self._get_mm_items(mm_missing_data) + + prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( + prompt_text=prompt_text, + mm_missing_data_items=mm_missing_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + mm_missing_next_idx = { + modality: 0 + for modality in mm_missing_data_items + } + + mm_merged_field_items = dict[str, list[Mapping[str, + MultiModalFieldItem]]]() + for modality, modal_items_lst in mm_maybe_cached_field_items.items(): + merged_modal_items_lst = list[Mapping[str, MultiModalFieldItem]]() + + for idx, modal_items in enumerate(modal_items_lst): + if modal_items is None: + modal_items = mm_missing_kwargs.get_items_by_modality( + modality, + mm_missing_next_idx[modality], + ) + + cache.put( + model_id, + modality, + mm_data_items[modality][idx], + hf_processor_mm_kwargs, + modal_items, + ) + + mm_missing_next_idx[modality] += 1 + + merged_modal_items_lst.append(modal_items) + + mm_merged_field_items[modality] = merged_modal_items_lst + + if self.enable_sanity_checks: + mm_missing_counts = mm_missing_data_items.get_item_counts() + assert all( + item_count == mm_missing_counts[modality] + for modality, item_count in mm_missing_next_idx.items()), dict( + mm_missing_next_idx=mm_missing_next_idx, + mm_missing_counts=mm_missing_counts) + + mm_kwargs = MultiModalKwargs.from_items_by_modality( + mm_merged_field_items, + enable_sanity_checks=self.enable_sanity_checks, + ) + + if self.enable_sanity_checks: + mm_item_counts = mm_data_items.get_item_counts() + + for modality, item_count in mm_item_counts.items(): + for item_idx in range(item_count): + try: + mm_kwargs.get_items_by_modality(modality, item_idx) + except Exception as e: + # Make it easy to set a breakpoint in the debugger + raise e + + return prompt_ids, mm_kwargs def _bind_prompt_replacements( self, @@ -730,6 +894,10 @@ def _apply_prompt_replacements( tokenizer = self._get_tokenizer() token_matches = find_token_matches(token_ids, prompt_repls) + mm_match_counts = { + modality: len(matches) + for modality, matches in full_groupby_modality(token_matches) + } # If the search text does not represent a special token, # it may have different token IDs in the prompt, because @@ -742,8 +910,8 @@ def _apply_prompt_replacements( # of the search text in the prompt, we instead perform string # replacement on the decoded token IDs, then encode them back. 
if all( - len(matches) >= mm_item_counts[modality] - for modality, matches in full_groupby_modality(token_matches) + mm_match_counts.get(modality, 0) >= item_count + for modality, item_count in mm_item_counts.items() ): # yapf: disable token_ids = replace_token_matches( token_ids, @@ -775,7 +943,7 @@ def apply( self, prompt_text: str, mm_data: MultiModalDataDict, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: """ Process multi-modal inputs to be used in vLLM. @@ -792,20 +960,24 @@ def apply( """ mm_items = self._get_mm_items(mm_data) - hf_inputs = self._apply_hf_processor(prompt_text, mm_items, - mm_processor_kwargs) - prompt_ids, = hf_inputs.pop("input_ids").tolist() - mm_kwargs = MultiModalKwargs(hf_inputs) + prompt_ids, mm_kwargs = self._cached_apply_hf_processor( + prompt_text, + mm_items, + hf_processor_mm_kwargs, + ) - prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs, - mm_processor_kwargs) - all_prompt_repls = self._bind_prompt_replacements(prompt_repls) + unbound_prompt_repls = self._get_prompt_replacements( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + prompt_repls = self._bind_prompt_replacements(unbound_prompt_repls) # If HF processor already inserts placeholder tokens, # there is no need for us to insert them mm_item_counts = mm_items.get_item_counts() - all_placeholders = self._find_placeholders(all_prompt_repls, - prompt_ids, mm_item_counts) + all_placeholders = self._find_placeholders(prompt_repls, prompt_ids, + mm_item_counts) if all_placeholders: tokenizer = self._get_tokenizer() @@ -817,7 +989,7 @@ def apply( all_placeholders, ) = self._apply_prompt_replacements( prompt_ids, - all_prompt_repls, + prompt_repls, mm_item_counts, ) @@ -855,23 +1027,29 @@ def get_dummy_data( from vllm.sequence import SequenceData processor_inputs = self._get_dummy_mm_inputs(mm_counts) - mm_inputs = self.apply(*processor_inputs) + mm_inputs = self.apply( + prompt_text=processor_inputs.prompt_text, + mm_data=processor_inputs.mm_data, + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + ) prompt_token_ids = mm_inputs["prompt_token_ids"] placeholders_by_modality = mm_inputs["mm_placeholders"] - total_placeholders_by_modality = dict[str, int]() - for modality, placeholders in placeholders_by_modality.items(): - num_placeholders = sum(item["length"] for item in placeholders) - max_tokens = mm_max_tokens[modality] - - if num_placeholders != max_tokens: - logger.warning( - "The processed dummy data has a total of %d placeholder " - "tokens for the '%s' modality, which is not the expected " - "%d tokens.", num_placeholders, modality, max_tokens) - - total_placeholders_by_modality[modality] = num_placeholders + total_placeholders_by_modality = { + modality: sum(item["length"] for item in placeholders) + for modality, placeholders in placeholders_by_modality.items() + } + expected_placeholders_by_modality = { + modality: mm_max_tokens[modality] + for modality in placeholders_by_modality + } + if total_placeholders_by_modality != expected_placeholders_by_modality: + raise AssertionError( + f"The processed dummy data has a total of " + f"{total_placeholders_by_modality} placeholder tokens, which " + f"is not the expected {expected_placeholders_by_modality} " + "tokens.") total_len = len(prompt_token_ids) if total_len > seq_len: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index ded45a7184b5d..3a5e11867ad9e 100644 --- a/vllm/multimodal/registry.py +++ 
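The rewritten condition above only asks, per modality, whether the tokenizer-level search found at least as many placeholder matches as there are data items; if not, the code falls back to decoding, string replacement, and re-encoding. The check in isolation:

```python
def can_replace_in_token_space(mm_match_counts: dict[str, int],
                               mm_item_counts: dict[str, int]) -> bool:
    # Token-level replacement is only safe when every modality's placeholder
    # was matched at least once per data item; otherwise the caller decodes,
    # replaces in the string, and re-encodes.
    return all(
        mm_match_counts.get(modality, 0) >= item_count
        for modality, item_count in mm_item_counts.items())


assert can_replace_in_token_space({"image": 2}, {"image": 2})
assert not can_replace_in_token_space({}, {"image": 1})
```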
b/vllm/multimodal/registry.py @@ -1,10 +1,9 @@ import functools from collections import UserDict -from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, +from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol, Sequence, Type, TypeVar) import torch.nn as nn -from typing_extensions import TypeAlias from vllm.inputs import InputProcessingContext from vllm.logger import init_logger @@ -15,7 +14,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import BaseMultiModalProcessor +from .processing import BaseMultiModalProcessor, ProcessingCache from .video import VideoPlugin if TYPE_CHECKING: @@ -23,15 +22,22 @@ logger = init_logger(__name__) +# TODO: Tune the MM cache size +MM_CACHE_SIZE = 256 + N = TypeVar("N", bound=Type[nn.Module]) -MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], - BaseMultiModalProcessor] -""" -Constructs a :class:`MultiModalProcessor` instance from the context. -The processing metadata should be derived from the context. -""" +class MultiModalProcessorFactory(Protocol): + """Constructs a :class:`MultiModalProcessor` instance from the context.""" + + def __call__( + self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + ) -> BaseMultiModalProcessor: + ... class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): @@ -71,6 +77,8 @@ def __init__( self._limits_by_model = _MultiModalLimits() + self._processing_cache = ProcessingCache(MM_CACHE_SIZE) + def register_plugin(self, plugin: MultiModalPlugin) -> None: """ Register a multi-modal plugin so it can be recognized by vLLM. @@ -328,15 +336,18 @@ def wrapper(model_cls: N) -> N: return wrapper - def has_processor(self, model_config: "ModelConfig") -> bool: - """ - Test whether a multi-modal processor is defined for a specific model. - """ + def _get_model_cls(self, model_config: "ModelConfig"): # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) - return model_cls in self._processor_factories + return model_cls + + def has_processor(self, model_config: "ModelConfig") -> bool: + """ + Test whether a multi-modal processor is defined for a specific model. + """ + return self._get_model_cls(model_config) in self._processor_factories def create_processor( self, @@ -346,12 +357,11 @@ def create_processor( """ Create a multi-modal processor for a specific model and tokenizer. 
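Switching `MultiModalProcessorFactory` from a `Callable` alias to a `Protocol` is what lets the keyword-only `cache` parameter become part of the expected signature. A minimal illustration of the pattern, using hypothetical stand-in names:

```python
from typing import Optional, Protocol


class Cache:  # stand-in for ProcessingCache
    pass


class ProcessorFactory(Protocol):
    """Anything callable with this exact signature satisfies the protocol."""

    def __call__(self, ctx: str, *, cache: Optional[Cache] = None) -> object:
        ...


def build_processor(ctx: str, *, cache: Optional[Cache] = None) -> object:
    # A plain function (or a class constructor) with a matching signature
    # can be assigned wherever a ProcessorFactory is expected.
    return {"ctx": ctx, "cache": cache}


factory: ProcessorFactory = build_processor
processor = factory("dummy-context", cache=Cache())
```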
""" - - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) + model_cls = self._get_model_cls(model_config) processor_factory = self._processor_factories[model_cls] ctx = InputProcessingContext(model_config, tokenizer) - return processor_factory(ctx) + cache = (None if model_config.disable_mm_preprocessor_cache else + self._processing_cache) + + return processor_factory(ctx, cache=cache) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index a49da2bdee972..87b12a6fb33c1 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,8 +1,7 @@ -import base64 -import os from functools import lru_cache -from io import BytesIO -from typing import List, Optional, Tuple, TypeVar, Union +from pathlib import Path +from typing import Optional, TypeVar, Union +from urllib.parse import ParseResult, urlparse import numpy as np import numpy.typing as npt @@ -10,283 +9,246 @@ from PIL import Image import vllm.envs as envs -from vllm.connections import global_http_connection +from vllm.connections import HTTPConnection, global_http_connection from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer -from vllm.utils import PlaceholderModule -from .inputs import MultiModalDataDict, PlaceholderRange - -try: - import decord -except ImportError: - decord = PlaceholderModule("decord") # type: ignore[assignment] - -try: - import librosa -except ImportError: - librosa = PlaceholderModule("librosa") # type: ignore[assignment] - -try: - import soundfile -except ImportError: - soundfile = PlaceholderModule("soundfile") # type: ignore[assignment] +from .audio import AudioMediaIO +from .base import MediaIO +from .image import ImageMediaIO +from .inputs import PlaceholderRange +from .video import VideoMediaIO logger = init_logger(__name__) cached_get_tokenizer = lru_cache(get_tokenizer) +_M = TypeVar("_M") -def _load_image_from_bytes(b: bytes) -> Image.Image: - image = Image.open(BytesIO(b)) - image.load() - return image - - -def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool: - # Get the common path - common_path = os.path.commonpath([ - os.path.abspath(image_path), - os.path.abspath(allowed_local_media_path) - ]) - # Check if the common path is the same as allowed_local_media_path - return common_path == os.path.abspath(allowed_local_media_path) +class MediaConnector: -def _load_image_from_file(image_url: str, - allowed_local_media_path: str) -> Image.Image: - if not allowed_local_media_path: - raise ValueError("Invalid 'image_url': Cannot load local files without" - "'--allowed-local-media-path'.") - if allowed_local_media_path: - if not os.path.exists(allowed_local_media_path): - raise ValueError( - "Invalid '--allowed-local-media-path': " - f"The path {allowed_local_media_path} does not exist.") - if not os.path.isdir(allowed_local_media_path): + def __init__( + self, + connection: HTTPConnection = global_http_connection, + *, + allowed_local_media_path: str = "", + ) -> None: + super().__init__() + + self.connection = connection + + if allowed_local_media_path: + allowed_local_media_path_ = Path(allowed_local_media_path) + + if not allowed_local_media_path_.exists(): + raise ValueError( + "Invalid `--allowed-local-media-path`: The path " + f"{allowed_local_media_path_} does not exist.") + if not allowed_local_media_path_.is_dir(): + raise ValueError( + "Invalid `--allowed-local-media-path`: The path " + 
f"{allowed_local_media_path_} must be a directory.") + else: + allowed_local_media_path_ = None + + self.allowed_local_media_path = allowed_local_media_path_ + + def _load_data_url( + self, + url_spec: ParseResult, + media_io: MediaIO[_M], + ) -> _M: + data_spec, data = url_spec.path.split(",", 1) + media_type, data_type = data_spec.split(";", 1) + + if data_type != "base64": + msg = "Only base64 data URLs are supported for now." + raise NotImplementedError(msg) + + return media_io.load_base64(media_type, data) + + def _load_file_url( + self, + url_spec: ParseResult, + media_io: MediaIO[_M], + ) -> _M: + allowed_local_media_path = self.allowed_local_media_path + if allowed_local_media_path is None: + raise RuntimeError("Cannot load local files without " + "`--allowed-local-media-path`.") + + filepath = Path(url_spec.path) + if allowed_local_media_path not in filepath.resolve().parents: raise ValueError( - "Invalid '--allowed-local-media-path': " - f"The path {allowed_local_media_path} must be a directory.") - - # Only split once and assume the second part is the image path - _, image_path = image_url.split("file://", 1) - if not _is_subpath(image_path, allowed_local_media_path): - raise ValueError( - f"Invalid 'image_url': The file path {image_path} must" - " be a subpath of '--allowed-local-media-path'" - f" '{allowed_local_media_path}'.") - - image = Image.open(image_path) - image.load() - return image + f"The file path {filepath} must be a subpath " + f"of `--allowed-local-media-path` {allowed_local_media_path}.") + return media_io.load_file(filepath) -def _load_image_from_data_url(image_url: str) -> Image.Image: - # Only split once and assume the second part is the base64 encoded image - _, image_base64 = image_url.split(",", 1) - return load_image_from_base64(image_base64) - - -def fetch_image(image_url: str, - *, - image_mode: str = "RGB", - allowed_local_media_path: str = "") -> Image.Image: - """ - Load a PIL image from a HTTP or base64 data URL. - - By default, the image is converted into RGB format. - """ - if image_url.startswith('http'): - image_raw = global_http_connection.get_bytes( - image_url, - timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, - ) - image = _load_image_from_bytes(image_raw) - - elif image_url.startswith('data:image'): - image = _load_image_from_data_url(image_url) - elif image_url.startswith('file://'): - image = _load_image_from_file(image_url, allowed_local_media_path) - else: - raise ValueError("Invalid 'image_url': A valid 'image_url' must start " - "with either 'data:image', 'file://' or 'http'.") - - return image.convert(image_mode) - - -async def async_fetch_image(image_url: str, - *, - image_mode: str = "RGB", - allowed_local_media_path: str = "") -> Image.Image: - """ - Asynchronously load a PIL image from a HTTP or base64 data URL. - - By default, the image is converted into RGB format. 
- """ - if image_url.startswith('http'): - image_raw = await global_http_connection.async_get_bytes( - image_url, - timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, - ) - image = _load_image_from_bytes(image_raw) - - elif image_url.startswith('data:image'): - image = _load_image_from_data_url(image_url) - elif image_url.startswith('file://'): - image = _load_image_from_file(image_url, allowed_local_media_path) - else: - raise ValueError("Invalid 'image_url': A valid 'image_url' must start " - "with either 'data:image', 'file://' or 'http'.") + def load_from_url( + self, + url: str, + media_io: MediaIO[_M], + *, + fetch_timeout: Optional[int] = None, + ) -> _M: + url_spec = urlparse(url) - return image.convert(image_mode) + if url_spec.scheme.startswith("http"): + connection = self.connection + data = connection.get_bytes(url, timeout=fetch_timeout) + return media_io.load_bytes(data) -def _load_video_from_bytes(b: bytes, num_frames: int = 32) -> npt.NDArray: - video_path = BytesIO(b) - vr = decord.VideoReader(video_path, num_threads=1) - total_frame_num = len(vr) + if url_spec.scheme == "data": + return self._load_data_url(url_spec, media_io) - if total_frame_num > num_frames: - uniform_sampled_frames = np.linspace(0, - total_frame_num - 1, - num_frames, - dtype=int) - frame_idx = uniform_sampled_frames.tolist() - else: - frame_idx = [i for i in range(0, total_frame_num)] - frames = vr.get_batch(frame_idx).asnumpy() + if url_spec.scheme == "file": + return self._load_file_url(url_spec, media_io) - return frames + msg = "The URL must be either a HTTP, data or file URL." + raise ValueError(msg) + async def load_from_url_async( + self, + url: str, + media_io: MediaIO[_M], + *, + fetch_timeout: Optional[int] = None, + ) -> _M: + url_spec = urlparse(url) -def _load_video_from_data_url(video_url: str) -> npt.NDArray: - # Only split once and assume the second part is the base64 encoded video - _, video_base64 = video_url.split(",", 1) + if url_spec.scheme.startswith("http"): + connection = self.connection + data = await connection.async_get_bytes(url, timeout=fetch_timeout) - if video_url.startswith("data:video/jpeg;"): - return np.stack([ - np.array(load_image_from_base64(frame_base64)) - for frame_base64 in video_base64.split(",") - ]) + return media_io.load_bytes(data) - return load_video_from_base64(video_base64) + if url_spec.scheme == "data": + return self._load_data_url(url_spec, media_io) + if url_spec.scheme == "file": + return self._load_file_url(url_spec, media_io) -def fetch_video(video_url: str, *, num_frames: int = 32) -> npt.NDArray: - """ - Load video from a HTTP or base64 data URL. - """ - if video_url.startswith('http') or video_url.startswith('https'): - video_raw = global_http_connection.get_bytes( - video_url, - timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT, - ) - video = _load_video_from_bytes(video_raw, num_frames) - elif video_url.startswith('data:video'): - video = _load_video_from_data_url(video_url) - else: - raise ValueError("Invalid 'video_url': A valid 'video_url' must start " - "with either 'data:video' or 'http'.") - return video + msg = "The URL must be either a HTTP, data or file URL." + raise ValueError(msg) + def fetch_audio( + self, + audio_url: str, + ) -> tuple[np.ndarray, Union[int, float]]: + """ + Load audio from a URL. + """ + audio_io = AudioMediaIO() -async def async_fetch_video(video_url: str, - *, - num_frames: int = 32) -> npt.NDArray: - """ - Asynchronously load video from a HTTP or base64 data URL. - - By default, the image is converted into RGB format. 
- """ - if video_url.startswith('http') or video_url.startswith('https'): - video_raw = await global_http_connection.async_get_bytes( - video_url, - timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT, - ) - video = _load_video_from_bytes(video_raw, num_frames) - elif video_url.startswith('data:video'): - video = _load_video_from_data_url(video_url) - else: - raise ValueError("Invalid 'video_url': A valid 'video_url' must start " - "with either 'data:video' or 'http'.") - return video - - -def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: - """ - Load audio from a URL. - """ - if audio_url.startswith("http"): - audio_bytes = global_http_connection.get_bytes( + return self.load_from_url( audio_url, - timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, + audio_io, + fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, ) - elif audio_url.startswith("data:audio"): - _, audio_base64 = audio_url.split(",", 1) - audio_bytes = base64.b64decode(audio_base64) - else: - raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start " - "with either 'data:audio' or 'http'.") - - return librosa.load(BytesIO(audio_bytes), sr=None) + async def fetch_audio_async( + self, + audio_url: str, + ) -> tuple[np.ndarray, Union[int, float]]: + """ + Asynchronously fetch audio from a URL. + """ + audio_io = AudioMediaIO() -async def async_fetch_audio( - audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: - """ - Asynchronously fetch audio from a URL. - """ - if audio_url.startswith("http"): - audio_bytes = await global_http_connection.async_get_bytes( + return await self.load_from_url_async( audio_url, - timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, + audio_io, + fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, ) - elif audio_url.startswith("data:audio"): - _, audio_base64 = audio_url.split(",", 1) - audio_bytes = base64.b64decode(audio_base64) - else: - raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start " - "with either 'data:audio' or 'http'.") - - return librosa.load(BytesIO(audio_bytes), sr=None) + def fetch_image( + self, + image_url: str, + *, + image_mode: str = "RGB", + ) -> Image.Image: + """ + Load a PIL image from a HTTP or base64 data URL. -def get_and_parse_audio(audio_url: str) -> MultiModalDataDict: - audio, sr = fetch_audio(audio_url) - return {"audio": (audio, sr)} + By default, the image is converted into RGB format. + """ + image_io = ImageMediaIO(image_mode=image_mode) + return self.load_from_url( + image_url, + image_io, + fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, + ) -def get_and_parse_image( + async def fetch_image_async( + self, image_url: str, *, - allowed_local_media_path: str = "") -> MultiModalDataDict: - image = fetch_image(image_url, - allowed_local_media_path=allowed_local_media_path) - return {"image": image} - + image_mode: str = "RGB", + ) -> Image.Image: + """ + Asynchronously load a PIL image from a HTTP or base64 data URL. -def get_and_parse_video(video_url: str) -> MultiModalDataDict: - video = fetch_video(video_url) - return {"video": video} + By default, the image is converted into RGB format. 
+ """ + image_io = ImageMediaIO(image_mode=image_mode) + return await self.load_from_url_async( + image_url, + image_io, + fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, + ) -async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict: - audio, sr = await async_fetch_audio(audio_url) - return {"audio": (audio, sr)} - + def fetch_video( + self, + video_url: str, + *, + image_mode: str = "RGB", + num_frames: int = 32, + ) -> npt.NDArray: + """ + Load video from a HTTP or base64 data URL. + """ + image_io = ImageMediaIO(image_mode=image_mode) + video_io = VideoMediaIO(image_io, num_frames=num_frames) + + return self.load_from_url( + video_url, + video_io, + fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT, + ) -async def async_get_and_parse_image( - image_url: str, + async def fetch_video_async( + self, + video_url: str, *, - allowed_local_media_path: str = "") -> MultiModalDataDict: - image = await async_fetch_image( - image_url, allowed_local_media_path=allowed_local_media_path) - return {"image": image} + image_mode: str = "RGB", + num_frames: int = 32, + ) -> npt.NDArray: + """ + Asynchronously load video from a HTTP or base64 data URL. + + By default, the image is converted into RGB format. + """ + image_io = ImageMediaIO(image_mode=image_mode) + video_io = VideoMediaIO(image_io, num_frames=num_frames) + + return await self.load_from_url_async( + video_url, + video_io, + fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT, + ) -async def async_get_and_parse_video(video_url: str) -> MultiModalDataDict: - video = await async_fetch_video(video_url) - return {"video": video} +global_media_connector = MediaConnector() +"""The global :class:`MediaConnector` instance used by vLLM.""" + +fetch_audio = global_media_connector.fetch_audio +fetch_image = global_media_connector.fetch_image +fetch_video = global_media_connector.fetch_video def encode_audio_base64( @@ -294,10 +256,8 @@ def encode_audio_base64( sampling_rate: int, ) -> str: """Encode audio as base64.""" - buffered = BytesIO() - soundfile.write(buffered, audio, sampling_rate, format="WAV") - - return base64.b64encode(buffered.getvalue()).decode('utf-8') + audio_io = AudioMediaIO() + return audio_io.encode_base64((audio, sampling_rate)) def encode_image_base64( @@ -311,29 +271,14 @@ def encode_image_base64( By default, the image is converted into RGB format before being encoded. 
""" - buffered = BytesIO() - image = image.convert(image_mode) - image.save(buffered, format) - return base64.b64encode(buffered.getvalue()).decode('utf-8') - - -def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: - """Load image from base64 format.""" - return _load_image_from_bytes(base64.b64decode(image)) + image_io = ImageMediaIO(image_mode=image_mode) + return image_io.encode_base64(image, image_format=format) def encode_video_base64(frames: npt.NDArray) -> str: - base64_frames = [] - frames_list = [frames[i] for i in range(frames.shape[0])] - for frame in frames_list: - img_base64 = encode_image_base64(Image.fromarray(frame)) - base64_frames.append(img_base64) - return ",".join(base64_frames) - - -def load_video_from_base64(video: Union[bytes, str]) -> npt.NDArray: - """Load video from base64 format.""" - return _load_video_from_bytes(base64.b64decode(video)) + image_io = ImageMediaIO() + video_io = VideoMediaIO(image_io) + return video_io.encode_base64(frames) def resolve_visual_encoder_outputs( @@ -389,7 +334,7 @@ def repeat_and_pad_token( repeat_count: int = 1, pad_token_left: Optional[_T] = None, pad_token_right: Optional[_T] = None, -) -> List[_T]: +) -> list[_T]: replacement = [token] * repeat_count if pad_token_left is not None: replacement = [pad_token_left] + replacement @@ -402,13 +347,13 @@ def repeat_and_pad_token( def repeat_and_pad_placeholder_tokens( tokenizer: AnyTokenizer, prompt: Optional[str], - prompt_token_ids: List[int], + prompt_token_ids: list[int], *, placeholder_token_id: int, - repeat_count: Union[int, List[int]], + repeat_count: Union[int, list[int]], pad_token_left: Optional[int] = None, pad_token_right: Optional[int] = None, -) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]: +) -> tuple[Optional[str], list[int], list[PlaceholderRange]]: if isinstance(repeat_count, int): repeat_count = [repeat_count] @@ -450,8 +395,8 @@ def repeat_and_pad_placeholder_tokens( new_prompt += prompt_parts[i] + replacement_str new_prompt += prompt_parts[-1] - new_token_ids: List[int] = [] - placeholder_ranges: List[PlaceholderRange] = [] + new_token_ids = list[int]() + placeholder_ranges = list[PlaceholderRange]() placeholder_token_idx = 0 for i, token in enumerate(prompt_token_ids): if token == placeholder_token_id: @@ -481,7 +426,7 @@ def repeat_and_pad_placeholder_tokens( def consecutive_placeholder_ranges( num_items: int, item_size: int, - initial_offset: int = 0) -> List[PlaceholderRange]: + initial_offset: int = 0) -> list[PlaceholderRange]: """Returns a list of consecutive PlaceholderRanges of a fixed size""" return [ diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index c4be100562703..b7d43c830cc46 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,23 +1,32 @@ -from functools import lru_cache +import base64 +from functools import lru_cache, partial +from io import BytesIO +from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional import cv2 import numpy as np import numpy.typing as npt +from PIL import Image from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils import is_list_of +from vllm.utils import PlaceholderModule, is_list_of -from .base import MultiModalData -from .image import ImagePlugin +from .base import MediaIO, MultiModalData +from .image import ImageMediaIO, ImagePlugin from .inputs import 
MultiModalKwargs, VideoItem if TYPE_CHECKING: from vllm.config import ModelConfig +try: + import decord +except ImportError: + decord = PlaceholderModule("decord") # type: ignore[assignment] + logger = init_logger(__name__) cached_get_video_processor = lru_cache(get_video_processor) @@ -107,3 +116,73 @@ def sample_frames_from_video(frames: npt.NDArray, frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) sampled_frames = frames[frame_indices, ...] return sampled_frames + + +class VideoMediaIO(MediaIO[npt.NDArray]): + + def __init__( + self, + image_io: ImageMediaIO, + *, + num_frames: int = 32, + ) -> None: + super().__init__() + + self.image_io = image_io + self.num_frames = num_frames + + def load_bytes(self, data: bytes) -> npt.NDArray: + vr = decord.VideoReader(BytesIO(data), num_threads=1) + total_frame_num = len(vr) + + num_frames = self.num_frames + if total_frame_num > num_frames: + uniform_sampled_frames = np.linspace(0, + total_frame_num - 1, + num_frames, + dtype=int) + frame_idx = uniform_sampled_frames.tolist() + else: + frame_idx = list(range(0, total_frame_num)) + + return vr.get_batch(frame_idx).asnumpy() + + def load_base64(self, media_type: str, data: str) -> npt.NDArray: + if media_type.lower() == "video/jpeg": + load_frame = partial( + self.image_io.load_base64, + "image/jpeg", + ) + + return np.stack([ + np.array(load_frame(frame_data)) + for frame_data in data.split(",") + ]) + + return self.load_bytes(base64.b64decode(data)) + + def load_file(self, filepath: Path) -> npt.NDArray: + with filepath.open("rb") as f: + data = f.read() + + return self.load_bytes(data) + + def encode_base64( + self, + media: npt.NDArray, + *, + video_format: str = "JPEG", + ) -> str: + video = media + + if video_format == "JPEG": + encode_frame = partial( + self.image_io.encode_base64, + image_format=video_format, + ) + + return ",".join( + encode_frame(Image.fromarray(frame)) for frame in video) + + msg = "Only JPEG format is supported for now." + raise NotImplementedError(msg) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index f1523667b0466..b12cc83a22970 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,25 +1,31 @@ from functools import lru_cache from typing import Any, cast +from transformers.processing_utils import ProcessorMixin + def get_processor( processor_name: str, *args: Any, trust_remote_code: bool = False, + processor_cls: type[ProcessorMixin] = ProcessorMixin, **kwargs: Any, ): """Load a processor for the given model name via HuggingFace.""" # don't put this import at the top level # it will call torch.cuda.device_count() from transformers import AutoProcessor - from transformers.processing_utils import ProcessorMixin + + processor_factory = (AutoProcessor + if processor_cls == ProcessorMixin else processor_cls) try: - processor = AutoProcessor.from_pretrained( + processor = processor_factory.from_pretrained( processor_name, *args, trust_remote_code=trust_remote_code, - **kwargs) + **kwargs, + ) except ValueError as e: # If the error pertains to the processor class not existing or not # currently being imported, suggest using the --trust-remote-code flag. 
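The decord-based loader keeps at most `num_frames` frames by sampling indices uniformly across the clip; the index selection itself is plain NumPy and can be shown without decord:

```python
import numpy as np


def sample_frame_indices(total_frames: int, num_frames: int) -> list[int]:
    # Uniformly spaced indices when the clip exceeds the frame budget,
    # otherwise keep every frame (the same rule as VideoMediaIO.load_bytes).
    if total_frames > num_frames:
        return np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist()
    return list(range(total_frames))


print(sample_frame_indices(100, 8))  # [0, 14, 28, 42, 56, 70, 84, 99]
print(sample_frame_indices(5, 8))    # [0, 1, 2, 3, 4]
```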
diff --git a/vllm/utils.py b/vllm/utils.py index 3d198887021dc..5eb4e8c4180c4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -25,11 +25,11 @@ import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import OrderedDict, UserDict, defaultdict -from collections.abc import Iterable, Mapping +from collections.abc import Hashable, Iterable, Mapping from dataclasses import dataclass, field from functools import lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, - Dict, Generator, Generic, Hashable, List, Literal, + Dict, Generator, Generic, List, Literal, NamedTuple, Optional, Tuple, Type, TypeVar, Union, overload) from uuid import uuid4 @@ -194,13 +194,29 @@ def reset(self) -> None: self.counter = 0 +class CacheInfo(NamedTuple): + hits: int + total: int + + @property + def hit_ratio(self) -> float: + if self.total == 0: + return 0 + + return self.hits / self.total + + class LRUCache(Generic[_K, _V]): + """Note: This class is not thread safe!""" def __init__(self, capacity: int) -> None: self.cache = OrderedDict[_K, _V]() self.pinned_items = set[_K]() self.capacity = capacity + self._hits = 0 + self._total = 0 + def __contains__(self, key: _K) -> bool: return key in self.cache @@ -218,6 +234,9 @@ def __setitem__(self, key: _K, value: _V) -> None: def __delitem__(self, key: _K) -> None: self.pop(key) + def stat(self) -> CacheInfo: + return CacheInfo(hits=self._hits, total=self._total) + def touch(self, key: _K) -> None: self.cache.move_to_end(key) @@ -226,8 +245,12 @@ def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]: if key in self.cache: value = self.cache[key] self.cache.move_to_end(key) + + self._hits += 1 else: value = default + + self._total += 1 return value def put(self, key: _K, value: _V) -> None:
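The hit-ratio bookkeeping added to `LRUCache` is independent of the eviction policy. A compact standalone version of the same idea built on `OrderedDict` (a sketch, not the vLLM class):

```python
from collections import OrderedDict
from typing import Generic, NamedTuple, Optional, TypeVar

_K = TypeVar("_K")
_V = TypeVar("_V")


class CacheInfo(NamedTuple):
    hits: int
    total: int

    @property
    def hit_ratio(self) -> float:
        return self.hits / self.total if self.total else 0.0


class CountingLRUCache(Generic[_K, _V]):
    """Evicts the least recently used entry; counts lookups for stat()."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._cache: "OrderedDict[_K, _V]" = OrderedDict()
        self._hits = 0
        self._total = 0

    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
        self._total += 1
        if key in self._cache:
            self._hits += 1
            self._cache.move_to_end(key)      # mark as most recently used
            return self._cache[key]
        return default

    def put(self, key: _K, value: _V) -> None:
        self._cache[key] = value
        self._cache.move_to_end(key)
        if len(self._cache) > self.capacity:
            self._cache.popitem(last=False)   # evict the oldest entry

    def stat(self) -> CacheInfo:
        return CacheInfo(hits=self._hits, total=self._total)


cache = CountingLRUCache[str, int](capacity=2)
cache.put("a", 1)
cache.get("a")
cache.get("b")
print(cache.stat().hit_ratio)  # 0.5
```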