diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 66df708adcb..32ee6686b6b 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index 250fa354b47..6e6463bc948 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py
index 1e40e766d67..cb2622d9b53 100644
--- a/server/tests/models/test_santacoder.py
+++ b/server/tests/models/test_santacoder.py
@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
+        input_chunks=generate_pb2.Input(
+            chunks=[
+                generate_pb2.InputChunk(
+                    text="<fim-prefix>def<fim-suffix>world<fim-middle>"
+                )
+            ]
+        ),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 735ab5eb562..943c3b0820d 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 81a02163dc9..e896c831bd3 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -7,6 +7,7 @@
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
@@ -86,7 +87,8 @@ def from_pb(
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
+
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d8c8838cfaa..acf77b09f5c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -11,9 +11,10 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Iterable, Optional, Tuple, List, Type, Dict
 
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
@@ -127,11 +128,13 @@ def to_pb(self) -> generate_pb2.CachedBatch:
         )
 
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer
+    ):
         batch_inputs = []
         max_truncation = 0
         for r in requests:
-            batch_inputs.append(r.inputs)
+            batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
             max_truncation = max(max_truncation, r.truncate)
 
         batch_tokenized_inputs = tokenizer(
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 4656fd45e84..d0f2b9154e2 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -20,6 +20,7 @@
     weight_files,
     Weights,
 )
+from text_generation_server.utils.chunks import concat_text_chunks
 
 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
 
@@ -91,7 +92,9 @@ def from_pb(
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
-            inputs.append(escape_custom_split_sequence(r.inputs))
+            inputs.append(
+                escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
+            )
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index e78a9655a5b..f507d669936 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -1,4 +1,5 @@
-import torch
+from io import BytesIO
+from PIL import Image
 import torch
 import time
@@ -21,11 +22,6 @@
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
-from text_generation_server.models.vlm_causal_lm import split
-
-import re
-
-IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
 
 tracer = trace.get_tracer(__name__)
 
@@ -109,7 +105,7 @@ def from_pb_processor(
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(r.input_chunks.chunks)
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -128,8 +124,15 @@
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
             prompt = []
-            for chunk in split(inp):
-                prompt.append(chunk["content"])
+            for chunk in inp:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    prompt.append(chunk.text)
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
+                    prompt.append(image)
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
             prompts.append(prompt)
 
         # The processor replaces the call to tokenizer, and
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index d9f905906ac..3133a137d9a 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -27,6 +27,7 @@
     Generation,
     GeneratedText,
 )
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -139,7 +140,7 @@ def from_pb(
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py
index d94b9526582..e883ce02e18 100644
--- a/server/text_generation_server/models/pali_gemma.py
+++ b/server/text_generation_server/models/pali_gemma.py
@@ -1,55 +1,48 @@
+from io import BytesIO
+from PIL import Image
 import torch
 import torch.distributed
 from opentelemetry import trace
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 from text_generation_server.models.vlm_causal_lm import (
     VlmCausalLM,
     VlmCausalLMBatch,
     image_text_replacement,
-    load_data_uri,
-    split,
 )
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     PaliGemmaForConditionalGeneration,
 )
-from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
+from transformers import AutoProcessor, AutoConfig
+
+from text_generation_server.pb.generate_pb2 import Request
 
 tracer = trace.get_tracer(__name__)
 
 
 class PaliGemmaBatch(VlmCausalLMBatch):
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[Request], tokenizer, processor, config
+    ):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
         for r in requests:
-            chunks = split(r.inputs)
             full_text = ""
             image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += "<bos>" + chunk["content"] + "\n"
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    #     image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    full_text += "<bos>" + chunk.text + "\n"
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
                     # TODO do_convert_RGB should be on by default ?
                     image = image.convert("RGB")
                     image_input = processor.image_processor(image, return_tensors="pt")
                     full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 6a0c812f277..3bd095564c3 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -6,6 +6,7 @@
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -93,7 +94,7 @@ def from_pb(
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
             next_token_choosers.append(
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 92d790709e7..59a6fab1cd0 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -1,12 +1,9 @@
-import re
 import torch
-import math
 
 from PIL import Image
 from io import BytesIO
-import base64
 
 from opentelemetry import trace
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Iterable, Optional, Tuple, List, Type, Dict
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
@@ -18,25 +15,6 @@
 
 tracer = trace.get_tracer(__name__)
 
-IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
-
-
-def split(string) -> List[Dict[str, str]]:
-    parts = []
-    cursor = 0
-    for pattern in IMAGES.finditer(string):
-        start = pattern.start()
-        if start != cursor:
-            parts.append({"type": "text", "content": string[cursor:start]})
-
-        parts.append({"type": "image", "content": pattern.group(1)})
-        cursor = pattern.end()
-
-    if cursor != len(string):
-        parts.append({"type": "text", "content": string[cursor:]})
-
-    return parts
-
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -129,13 +107,6 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features
 
 
-def load_data_uri(image_uri: str) -> Image.Image:
-    image_uri = image_uri.split(",")[-1]
-    content = base64.b64decode(image_uri)
-    image = Image.open(BytesIO(content))
-    return image
-
-
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
@@ -159,35 +130,26 @@ def filter(self, request_ids: List[int]):
         return batch
 
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
+    ):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
         for r in requests:
-            chunks = split(r.inputs)
             full_text = ""
             image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += chunk["content"]
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    #     image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    full_text += chunk.text
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
                     image_input = processor.image_processor(image, return_tensors="pt")
                     full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
diff --git a/server/text_generation_server/utils/chunks.py b/server/text_generation_server/utils/chunks.py
new file mode 100644
index 00000000000..73962ea39e1
--- /dev/null
+++ b/server/text_generation_server/utils/chunks.py
@@ -0,0 +1,27 @@
+from typing import Iterable
+
+from loguru import logger
+
+from text_generation_server.pb import generate_pb2
+
+
+def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:
+    """
+    Concatenate text in text chunks. Non-text chunks are dropped.
+    """
+    text = None
+    for chunk in chunks:
+        chunk_type = chunk.WhichOneof("chunk")
+        if chunk_type == "text":
+            if text is None:
+                text = chunk.text
+            else:
+                raise NotImplementedError("Request contained more than one text chunk")
+        else:
+            # We cannot reject this, e.g. warmup sends an image chunk.
+            logger.debug(f"Encountered non-text chunk type {chunk_type}")
+
+    if text is None:
+        raise NotImplementedError("Request without a text chunk")
+
+    return text