Skip to content

Commit

Permalink
hasher
Browse files Browse the repository at this point in the history
Signed-off-by: Roger Wang <[email protected]>
  • Loading branch information
ywang96 committed Jan 3, 2025
1 parent 34ec194 commit 9f19629
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 65 deletions.
2 changes: 2 additions & 0 deletions vllm/multimodal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base import MultiModalPlaceholderMap, MultiModalPlugin
from .hasher import MultiModalHasher
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalHashDict, MultiModalKwargs,
MultiModalPlaceholderDict, NestedTensors)
Expand All @@ -19,6 +20,7 @@
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalHashDict",
"MultiModalHasher",
"MultiModalKwargs",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",
Expand Down
64 changes: 64 additions & 0 deletions vllm/multimodal/hasher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pickle
from blake3 import blake3
import torch
import numpy as np
from PIL import Image
from typing import Iterable

from vllm.logger import init_logger

logger = init_logger(__name__)

class MultiModalHasher:
    """Hashes multi-modal keyword arguments into a stable hex digest.

    Used to build cache keys for processed multi-modal inputs: each
    ``(key, value)`` pair is flattened into bytes and folded into a single
    BLAKE3 digest.
    """

    @classmethod
    def serialize_item(cls, obj: object) -> bytes:
        """Serialize a single leaf value into bytes for hashing.

        Falls back to :mod:`pickle` (with a warning) for unrecognized types.
        """
        # Simple cases
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj
        if isinstance(obj, Image.Image):
            # NOTE(review): only raw pixel data is hashed; images that
            # differ solely in mode/size metadata could collide — consider
            # folding ``obj.mode`` and ``obj.size`` into the digest too.
            return obj.tobytes()

        # Convertible to NumPy arrays
        if isinstance(obj, torch.Tensor):
            # force=True detaches and copies to CPU so that CUDA or
            # grad-tracking tensors can be serialized; plain CPU tensors
            # produce the same bytes as obj.numpy() did.
            obj = obj.numpy(force=True)
        if isinstance(obj, (int, float)):
            obj = np.array(obj)
        if isinstance(obj, np.ndarray):
            return obj.tobytes()

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    @classmethod
    def item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        """Recursively flatten ``obj`` into ``(key_bytes, value_bytes)`` pairs.

        Container keys/indices are appended to ``key`` with a ``.`` separator
        so that nesting structure is reflected in the hashed key paths.
        """
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from cls.item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from cls.item_to_bytes(f"{key}.{k}", v)
        else:
            key_bytes = cls.serialize_item(key)
            value_bytes = cls.serialize_item(obj)
            yield key_bytes, value_bytes

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        """Return a BLAKE3 hex digest over all keyword arguments."""
        hasher = blake3()

        for k, v in kwargs.items():
            for k_bytes, v_bytes in cls.item_to_bytes(k, v):
                hasher.update(k_bytes)
                hasher.update(v_bytes)

        return hasher.hexdigest()
21 changes: 11 additions & 10 deletions vllm/multimodal/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby

from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
from .utils import hash_kwargs

logger = init_logger(__name__)

Expand Down Expand Up @@ -508,9 +508,9 @@ def get(
"""
self._maybe_log_cache_stats()

cache_key = hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
return self._cache.get(cache_key)

def put(
Expand All @@ -525,9 +525,9 @@ def put(
Put a processed multi-modal item into the cache
according to its dependencies (see :meth:`get`).
"""
cache_key = hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
self._cache.put(cache_key, output_kwargs)


Expand Down Expand Up @@ -952,9 +952,10 @@ def apply(
model_id = self.ctx.model_config.model
mm_hashes = {
modality: [
hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs) for item in items
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}
Expand Down
56 changes: 1 addition & 55 deletions vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import pickle
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse

import numpy as np
import numpy.typing as npt
import torch
from blake3 import blake3
from PIL import Image

import vllm.envs as envs
Expand Down Expand Up @@ -256,58 +254,6 @@ async def fetch_video_async(
fetch_video = global_media_connector.fetch_video


def serialize_item(obj: object) -> bytes:
    """Serialize a single leaf value into bytes for hashing.

    Falls back to :mod:`pickle` (with a warning) for unrecognized types.
    """
    # Simple cases
    if isinstance(obj, str):
        return obj.encode("utf-8")
    if isinstance(obj, bytes):
        return obj
    if isinstance(obj, Image.Image):
        # NOTE(review): only raw pixel data is hashed; images that differ
        # solely in mode/size metadata could collide — consider folding
        # ``obj.mode`` and ``obj.size`` into the digest too.
        return obj.tobytes()

    # Convertible to NumPy arrays
    if isinstance(obj, torch.Tensor):
        # force=True detaches and copies to CPU so that CUDA or
        # grad-tracking tensors can be serialized; plain CPU tensors
        # produce the same bytes as obj.numpy() did.
        obj = obj.numpy(force=True)
    if isinstance(obj, (int, float)):
        obj = np.array(obj)
    if isinstance(obj, np.ndarray):
        return obj.tobytes()

    logger.warning(
        "No serialization method found for %s. "
        "Falling back to pickle.", type(obj))

    return pickle.dumps(obj)


def item_to_bytes(
    key: str,
    obj: object,
) -> Iterable[tuple[bytes, bytes]]:
    """Recursively flatten *obj* into ``(key_bytes, value_bytes)`` pairs.

    Container keys/indices are appended to *key* with a ``.`` separator so
    that nesting structure is reflected in the hashed key paths.
    """
    if isinstance(obj, dict):
        # Mappings: qualify the key path with each entry's key.
        for sub_key, sub_value in obj.items():
            yield from item_to_bytes(f"{key}.{sub_key}", sub_value)
    elif isinstance(obj, (list, tuple)):
        # Sequences: qualify the key path with each element's index.
        for index, element in enumerate(obj):
            yield from item_to_bytes(f"{key}.{index}", element)
    else:
        # Leaf: serialize both the accumulated key path and the value.
        yield serialize_item(key), serialize_item(obj)


def hash_kwargs(**kwargs: object) -> str:
    """Return a BLAKE3 hex digest over all keyword arguments."""
    digest = blake3()

    # Flatten every keyword argument into (key, value) byte pairs and fold
    # them all into a single running digest.
    pairs = (
        pair
        for name, value in kwargs.items()
        for pair in item_to_bytes(name, value)
    )
    for key_bytes, value_bytes in pairs:
        digest.update(key_bytes)
        digest.update(value_bytes)

    return digest.hexdigest()


def encode_audio_base64(
audio: np.ndarray,
sampling_rate: int,
Expand Down

0 comments on commit 9f19629

Please sign in to comment.