server: use chunked inputs #1985

Merged: 1 commit, Jun 7, 2024
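For orientation, a minimal sketch of the chunked request shape the server consumes after this change, based on the test fixtures updated below (sampling and stopping parameters omitted): each request now carries an input_chunks field with typed chunks alongside the plain inputs string.

from text_generation_server.pb import generate_pb2

# Text-only request: the prompt is wrapped in a single text chunk,
# mirroring the fixtures in server/tests/models/.
request = generate_pb2.Request(
    id=0,
    inputs="Test",
    input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
    prefill_logprobs=True,
    truncate=100,
)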
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
@@ -30,7 +30,7 @@ jobs:
cancel-in-progress: true
runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
strategy:
matrix:
matrix:
include:
- name: "cuda"
label: ""
1 change: 1 addition & 0 deletions server/tests/models/test_bloom.py
@@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
1 change: 1 addition & 0 deletions server/tests/models/test_causal_lm.py
@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
8 changes: 8 additions & 0 deletions server/tests/models/test_santacoder.py
@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
@@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
input_chunks=generate_pb2.Input(
chunks=[
generate_pb2.InputChunk(
text="<fim-prefix>def<fim-suffix>world<fim-middle>"
)
]
),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
1 change: 1 addition & 0 deletions server/tests/models/test_seq2seq_lm.py
@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
4 changes: 3 additions & 1 deletion server/text_generation_server/models/causal_lm.py
@@ -7,6 +7,7 @@
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
Batch,
@@ -86,7 +87,8 @@ def from_pb(
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
inputs.append(concat_text_chunks(r.input_chunks.chunks))

next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
9 changes: 6 additions & 3 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -11,9 +11,10 @@
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from typing import Iterable, Optional, Tuple, List, Type, Dict

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
@@ -127,11 +128,13 @@ def to_pb(self) -> generate_pb2.CachedBatch:
)

@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer):
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer
):
batch_inputs = []
max_truncation = 0
for r in requests:
batch_inputs.append(r.inputs)
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
max_truncation = max(max_truncation, r.truncate)

batch_tokenized_inputs = tokenizer(
5 changes: 4 additions & 1 deletion server/text_generation_server/models/galactica.py
@@ -20,6 +20,7 @@
weight_files,
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks

# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

@@ -91,7 +92,9 @@ def from_pb(
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
inputs.append(
escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
21 changes: 12 additions & 9 deletions server/text_generation_server/models/idefics_causal_lm.py
@@ -1,4 +1,5 @@
import torch
from io import BytesIO
from PIL import Image
import torch
import time

@@ -21,11 +22,6 @@
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models.vlm_causal_lm import split

import re

IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


tracer = trace.get_tracer(__name__)
@@ -109,7 +105,7 @@ def from_pb_processor(
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
inputs.append(r.input_chunks.chunks)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
@@ -128,8 +124,15 @@ def from_pb_processor(
for inp in inputs:
# Each input is encoded into a list, where each element of this input list is either a string or a URL
prompt = []
for chunk in split(inp):
prompt.append(chunk["content"])
for chunk in inp:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
prompt.append(chunk.text)
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
prompt.append(image)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
prompts.append(prompt)

# The processor replaces the call to tokenizer, and
3 changes: 2 additions & 1 deletion server/text_generation_server/models/mamba.py
@@ -27,6 +27,7 @@
Generation,
GeneratedText,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -139,7 +140,7 @@ def from_pb(
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
inputs.append(concat_text_chunks(r.input_chunks.chunks))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
39 changes: 16 additions & 23 deletions server/text_generation_server/models/pali_gemma.py
@@ -1,55 +1,48 @@
from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional, Tuple
from typing import Iterable, Optional, Tuple
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLM,
VlmCausalLMBatch,
image_text_replacement,
load_data_uri,
split,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
from transformers import AutoProcessor, AutoConfig

from text_generation_server.pb.generate_pb2 import Request

tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(VlmCausalLMBatch):
@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += "<bos>" + chunk["content"] + "\n"
elif chunk["type"] == "image":
image = chunk["content"]
# Should never receive URLs anymore, processing should be done
# On the rust layer.
# This avoid making n queries per TP
# if image.startswith("https://") or image.startswith("http://"):
# image = processor.image_processor.fetch_images(image)
if image.startswith("data:"):
image = load_data_uri(image)
else:
raise RuntimeError(
"Cannot process input image not starting with data:"
)
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += "<bos>" + chunk.text + "\n"
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# TODO do_convert_RGB should be on by default ?
image = image.convert("RGB")
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
raise RuntimeError(f"Invalid chunk type {chunk_type}")

batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
3 changes: 2 additions & 1 deletion server/text_generation_server/models/seq2seq_lm.py
@@ -6,6 +6,7 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
@@ -93,7 +94,7 @@ def from_pb(
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
inputs.append(r.inputs)
inputs.append(concat_text_chunks(r.input_chunks.chunks))
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
next_token_choosers.append(
60 changes: 11 additions & 49 deletions server/text_generation_server/models/vlm_causal_lm.py
@@ -1,12 +1,9 @@
import re
import torch
import math
from PIL import Image
from io import BytesIO
import base64

from opentelemetry import trace
from typing import Optional, Tuple, List, Type, Dict
from typing import Iterable, Optional, Tuple, List, Type, Dict

from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
@@ -18,25 +15,6 @@

tracer = trace.get_tracer(__name__)

IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


def split(string) -> List[Dict[str, str]]:
parts = []
cursor = 0
for pattern in IMAGES.finditer(string):
start = pattern.start()
if start != cursor:
parts.append({"type": "text", "content": string[cursor:start]})

parts.append({"type": "image", "content": pattern.group(1)})
cursor = pattern.end()

if cursor != len(string):
parts.append({"type": "text", "content": string[cursor:]})

return parts


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
@@ -129,13 +107,6 @@ def get_number_of_features(height: int, width: int, config) -> int:
return unpadded_features + newline_features + base_features


def load_data_uri(image_uri: str) -> Image.Image:
image_uri = image_uri.split(",")[-1]
content = base64.b64decode(image_uri)
image = Image.open(BytesIO(content))
return image


class VlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
@@ -159,35 +130,26 @@ def filter(self, request_ids: List[int]):
return batch

@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
elif chunk["type"] == "image":
image = chunk["content"]
# Should never receive URLs anymore, processing should be done
# On the rust layer.
# This avoid making n queries per TP
# if image.startswith("https://") or image.startswith("http://"):
# image = processor.image_processor.fetch_images(image)
if image.startswith("data:"):
image = load_data_uri(image)
else:
raise RuntimeError(
"Cannot process input image not starting with data:"
)
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += chunk.text
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
raise RuntimeError(f"Invalid chunk type {chunk_type}")

batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
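The chunk handling above relies on protobuf oneof discrimination rather than regex parsing of the inputs string. A small illustrative sketch of how WhichOneof behaves, assuming only what the calls above use (an InputChunk whose oneof is named "chunk" and has a text variant):

from text_generation_server.pb import generate_pb2

# WhichOneof returns the name of whichever oneof field is set...
text_chunk = generate_pb2.InputChunk(text="hello")
assert text_chunk.WhichOneof("chunk") == "text"

# ...and None when no field of the oneof is set.
empty_chunk = generate_pb2.InputChunk()
assert empty_chunk.WhichOneof("chunk") is None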
27 changes: 27 additions & 0 deletions server/text_generation_server/utils/chunks.py
@@ -0,0 +1,27 @@
from typing import Iterable

from loguru import logger

from text_generation_server.pb import generate_pb2


def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:
    """
    Concatenate text in text chunks. Non-text chunks are dropped.
    """
    text = None
    for chunk in chunks:
        chunk_type = chunk.WhichOneof("chunk")
        if chunk_type == "text":
            if text is None:
                text = chunk.text
            else:
                raise NotImplementedError("Request contained more than one text chunk")
        else:
            # We cannot reject this, e.g. warmup sends an image chunk.
            logger.debug(f"Encountered non-text chunk type {chunk_type}")

    if text is None:
        raise NotImplementedError("Request without a text chunk")

    return text

Review thread on concat_text_chunks:

Member: Is this method only there to be future-proof, or is there a way today to have multiple text chunks?

Member (Author): AFAIK we can currently only have multiple text chunks in VLM models, so this was indeed only to future-proof.

Member: Then maybe we should take [0] and crash with an unreachable if len > 1?

Member (Author): Oh, we do need to iterate over the chunks, because we send image chunks unconditionally during warmup, even for text-only models. The current approach seems more robust? What do you think about logging a warning when len(texts) > 1?

Member (Author): Updated to:

  • Fail when there is more than one text chunk.
  • Fail when there is no text chunk.
  • Log at debug level when there is a non-text chunk (only log because e.g. warmup sends an image chunk).
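Following the thread above, a short usage sketch of the helper as merged (illustrative only; it reuses the protobuf fixtures from the tests in this PR):

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.chunks import concat_text_chunks

# Exactly one text chunk: its text is returned.
single = generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")])
assert concat_text_chunks(single.chunks) == "Test"

# More than one text chunk (or none at all) raises NotImplementedError;
# non-text chunks are only logged at debug level and otherwise skipped.
multi = generate_pb2.Input(
    chunks=[generate_pb2.InputChunk(text="a"), generate_pb2.InputChunk(text="b")]
)
try:
    concat_text_chunks(multi.chunks)
except NotImplementedError:
    pass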