Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve qwen2-vl startup #2802

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ impl ChunksToString for Vec<InputChunk> {
}
}

static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

pub type Result<T> = std::result::Result<T, ClientError>;
2 changes: 1 addition & 1 deletion backends/v2/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ impl From<transport::Error> for ClientError {
}
}

static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

pub type Result<T> = std::result::Result<T, ClientError>;
2 changes: 1 addition & 1 deletion backends/v3/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ impl From<Chunk> for InputChunk {
}
}

static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

pub type Result<T> = std::result::Result<T, ClientError>;
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The correct answer is: blue",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1733445131,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.4.2-dev0-native",
"usage": {
"completion_tokens": 7,
"prompt_tokens": 27,
"total_tokens": 34
}
}
38 changes: 38 additions & 0 deletions integration-tests/models/test_flash_qwen2_vl_warmup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest


@pytest.fixture(scope="module")
def flash_qwen2_vl_handle(launcher):
with launcher(
"Qwen/Qwen2-VL-2B-Instruct",
max_input_length=40,
max_batch_prefill_tokens=50,
max_total_tokens=51,
) as handle:
yield handle


@pytest.fixture(scope="module")
async def flash_qwen2(flash_qwen2_vl_handle):
await flash_qwen2_vl_handle.health(300)
return flash_qwen2_vl_handle.client


@pytest.mark.private
async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
response = await flash_qwen2.chat(
max_tokens=20,
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What is the color of the sky?"},
],
},
],
)

assert response.choices[0].message.content == "The correct answer is: blue"

assert response == response_snapshot
6 changes: 6 additions & 0 deletions server/text_generation_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
BloomForCausalLM,
)
from text_generation_server.models.globals import ATTENTION
import text_generation_server.models.globals as globals
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
Expand Down Expand Up @@ -1208,6 +1209,11 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == QWEN2_VL:
# TODO: remove edge case when cuda graph issue is resolved for BS=2 with Qwen2-VL
logger.warning(
"Qwen2-VL requires cuda graphs to be greater than 2. Removing all cuda graphs with a batch size equal or less than 2."
)
globals.CUDA_GRAPHS = list(filter(lambda x: x > 2, globals.CUDA_GRAPHS))
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,12 @@ def forward(
dim=-1,
)

self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
self.rotary_emb(
query,
torch.select(kv, dim=1, index=0),
cos[: query.shape[0], ...],
sin[: query.shape[0], ...],
)

if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ def forward(
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds

max_s = max(max_s, inputs_embeds.size(0))
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
Expand Down
18 changes: 11 additions & 7 deletions server/text_generation_server/models/flash_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@
MEM_POOL,
ATTENTION,
BLOCK_SIZE,
CUDA_GRAPHS,
REQUEST_LOGPROBS,
TGI_WIGGLE_ROOM,
get_adapter_to_index,
)

# avoid coping CUDA_GRAPHS value by importing globals as a module
import text_generation_server.models.globals as globals
from text_generation_server.layers.attention import KVCache, Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
Expand Down Expand Up @@ -1633,8 +1635,8 @@ def warmup(
int(val)
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
]
elif CUDA_GRAPHS is not None:
tuning_sequences = CUDA_GRAPHS
elif globals.CUDA_GRAPHS is not None:
tuning_sequences = globals.CUDA_GRAPHS
else:
tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

Expand Down Expand Up @@ -1673,13 +1675,14 @@ def warmup(
"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
)

if CUDA_GRAPHS:
if globals.CUDA_GRAPHS:
try:
log_master(
logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
logger.info,
f"Cuda Graphs are enabled for sizes {globals.CUDA_GRAPHS}",
)
# Warmup cuda graphs
for bs in CUDA_GRAPHS:
for bs in globals.CUDA_GRAPHS:
synchronize(self.device)
free_memory = get_free_memory(
self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
Expand All @@ -1703,7 +1706,8 @@ def warmup(
logger.exception("Decode cuda graph warmup failed")
else:
log_master(
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
logger.info,
f"Cuda Graphs are disabled (CUDA_GRAPHS={globals.CUDA_GRAPHS}).",
)

assert max_input_tokens is not None
Expand Down
23 changes: 10 additions & 13 deletions server/text_generation_server/models/vlm_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def batch_tokenized_inputs(
w = image.width * 2
h = image.height * 2
image = image.resize((w, h))

if config.model_type == "llava_next":
images.append(image)
else:
Expand All @@ -198,8 +197,8 @@ def batch_tokenized_inputs(
else:
image_inputs = None

batch_inputs = []
max_truncation = 0
batch_tokenized_inputs = []
max_length = 0
image_id = 0
for r in requests:
full_text = ""
Expand All @@ -214,16 +213,14 @@ def batch_tokenized_inputs(
image_id += 1

full_text = image_text_replacement_fixup(config, full_text)

batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)

batch_tokenized_inputs = tokenizer(
batch_inputs,
truncation=True,
max_length=max_truncation,
add_special_tokens=not config.model_type == "paligemma",
)["input_ids"]
input_ids = tokenizer(
full_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
max_length = max(max_length, len(input_ids))
batch_tokenized_inputs.append(input_ids)

return batch_tokenized_inputs, image_inputs

Expand Down
Loading