feat: add payload limit #2726

Merged: 2 commits merged on Nov 21, 2024. The diff below shows changes from 1 commit.

5 changes: 5 additions & 0 deletions backends/trtllm/src/main.rs
@@ -62,6 +62,8 @@ struct Args {
executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}

async fn get_tokenizer(
@@ -217,6 +219,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
auth_token,
executor_worker,
usage_stats,
payload_limit,
} = args;

// Launch Tokio runtime
@@ -287,6 +290,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
tokenizer_name,
tokenizer_config_path,
revision,
false,
Review comment (Collaborator): Is this intentional?

hostname,
port,
cors_allow_origin,
@@ -296,6 +300,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
true,
max_client_batch_size,
usage_stats,
payload_limit,
)
.await?;
Ok(())
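
The same clap-derived flag is added to each of the three router binaries (trtllm here, v2 and v3 below). As a rough standalone sketch of the pattern, assuming clap 4 with the derive and env features enabled, and with an illustrative struct that is not the project's actual code: a bare long and env on the field should derive the --payload-limit flag and a PAYLOAD_LIMIT environment variable from the field name.

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    // Payload size limit in bytes; read from --payload-limit or the
    // PAYLOAD_LIMIT env var, defaulting to 2_000_000 (about 2 MB).
    #[clap(default_value = "2000000", long, env)]
    payload_limit: usize,
}

fn main() {
    let args = Args::parse();
    println!("payload limit: {} bytes", args.payload_limit);
}
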
4 changes: 4 additions & 0 deletions backends/v2/src/main.rs
@@ -70,6 +70,8 @@ struct Args {
max_client_batch_size: usize,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}

#[derive(Debug, Subcommand)]
@@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
} = args;

if let Some(Commands::PrintSchema) = command {
@@ -194,6 +197,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
)
.await?;
Ok(())
4 changes: 4 additions & 0 deletions backends/v3/src/main.rs
@@ -70,6 +70,8 @@ struct Args {
max_client_batch_size: usize,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}

#[derive(Debug, Subcommand)]
@@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
} = args;

if let Some(Commands::PrintSchema) = command {
@@ -210,6 +213,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
)
.await?;
Ok(())
8 changes: 8 additions & 0 deletions launcher/src/main.rs
@@ -687,6 +687,12 @@ struct Args {
/// Defaul is on.
#[clap(default_value = "on", long, env)]
usage_stats: UsageStatsLevel,

/// Payload size limit in bytes
///
/// Default is 2MB
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}

#[derive(Debug)]
@@ -1474,6 +1480,8 @@ fn spawn_webserver(
format!("{}-0", args.shard_uds_path),
"--tokenizer-name".to_string(),
args.model_id,
"--payload-limit".to_string(),
args.payload_limit.to_string(),
];
if let Some(max_input_tokens) = max_input_tokens {
router_args.extend_from_slice(&[
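
The launcher does not enforce the limit itself; it only documents the flag and forwards the value to the router process on its command line. A rough sketch of that forwarding, with a hypothetical binary name and a trimmed-down argument list rather than the launcher's actual code:

use std::process::{Child, Command};

fn spawn_webserver_sketch(model_id: &str, payload_limit: usize) -> std::io::Result<Child> {
    // Build the router's CLI arguments, forwarding the launcher-level setting.
    let router_args = vec![
        "--tokenizer-name".to_string(),
        model_id.to_string(),
        "--payload-limit".to_string(),
        payload_limit.to_string(),
    ];
    // The binary name below is assumed for illustration.
    Command::new("text-generation-router").args(&router_args).spawn()
}
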
6 changes: 5 additions & 1 deletion router/src/server.rs
@@ -31,7 +31,7 @@ use crate::{
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::extract::{DefaultBodyLimit, Extension};
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
@@ -1673,6 +1673,7 @@ pub async fn run(
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
payload_limit: usize,
) -> Result<(), WebServerError> {
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
@@ -1926,6 +1927,7 @@ pub async fn run(
model_info,
compat_return_full_text,
allow_origin,
payload_limit,
)
.await;

@@ -1985,6 +1987,7 @@ async fn start(
model_info: HubModelInfo,
compat_return_full_text: bool,
allow_origin: Option<AllowOrigin>,
payload_limit: usize,
) -> Result<(), WebServerError> {
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
@@ -2382,6 +2385,7 @@ async fn start(
.layer(Extension(compute_type))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(DefaultBodyLimit::max(payload_limit))
.layer(cors_layer);

tracing::info!("Connected");
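
On the router side the value is wired into axum's DefaultBodyLimit layer, which caps how many bytes body-consuming extractors will read; oversized requests are rejected before the handler runs (normally with 413 Payload Too Large). A minimal standalone sketch of the same layering, assuming axum 0.7, tokio, and serde_json, with an illustrative route and handler rather than the server's actual ones:

use axum::{extract::DefaultBodyLimit, routing::post, Json, Router};
use serde_json::Value;

// Echo handler; Json is a body-consuming extractor, so it honors the limit.
async fn generate(Json(payload): Json<Value>) -> Json<Value> {
    Json(payload)
}

#[tokio::main]
async fn main() {
    let payload_limit: usize = 2_000_000; // bytes, mirroring the new default

    let app = Router::new()
        .route("/generate", post(generate))
        // Requests whose bodies exceed payload_limit never reach the handler.
        .layer(DefaultBodyLimit::max(payload_limit));

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
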
11 changes: 5 additions & 6 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -962,9 +962,9 @@ def prepare_for_prefill(self):
self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device
)
self.cu_seqlen_prefill = torch.nn.functional.pad(
torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
).to(torch.int32)
cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
Review comment (Member): Are these changes in the same PR intentional?

torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
self.cache_lengths_tensor = torch.tensor(
self.cache_lengths, dtype=torch.int32, device=device
)
@@ -2020,9 +2020,8 @@ def generate_token(

# For each member of the batch
# Cumulative length
cu_accepted_ids = torch.nn.functional.pad(
torch.cumsum(accepted_ids, dim=0), (1, 0)
)
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
cumulative_length = 0
for i, (
request,
5 changes: 3 additions & 2 deletions server/text_generation_server/models/metadata_kernels.py
@@ -66,8 +66,9 @@ def block_tables_to_ragged(
)

if has_triton():
cu_seqlen = torch.nn.functional.pad(
torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
cu_seqlen = input_lengths_tensor.new_zeros(input_lengths_tensor.shape[0] + 1)
torch.cumsum(
input_lengths_tensor + cache_lengths_tensor, out=cu_seqlen[1:], dim=0
)

def grid(meta):
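
The Python hunks above (two in flash_causal_lm.py, one in metadata_kernels.py) all make the same change: instead of computing a prefix sum and then left-padding it with torch.nn.functional.pad, they preallocate a zeroed buffer of length n + 1 and write the cumulative sum directly into its tail with torch.cumsum(..., out=buf[1:]), avoiding the extra tensor that pad would allocate. Either way the result is a cumulative-length vector with a leading zero, so indices cu[i]..cu[i+1] bracket sequence i in the flattened batch. A plain Rust sketch of the values being produced, illustrative only and not project code:

// Cumulative lengths with a leading zero: [0, l0, l0 + l1, ...].
fn cu_lengths(lengths: &[u32]) -> Vec<u32> {
    let mut cu = vec![0u32; lengths.len() + 1];
    for (i, &len) in lengths.iter().enumerate() {
        cu[i + 1] = cu[i] + len;
    }
    cu
}

fn main() {
    // Both the padded-cumsum and the out= variant produce the same values.
    assert_eq!(cu_lengths(&[3, 5, 2]), vec![0, 3, 8, 10]);
}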