diff --git a/launcher/src/main.rs b/launcher/src/main.rs index b33e698a4d2..7a7909b683b 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1162,7 +1162,7 @@ fn spawn_webserver( max_input_tokens: usize, max_total_tokens: usize, max_batch_prefill_tokens: u32, - startup_time: u64, + download_time: u64, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, ) -> Result { @@ -1200,8 +1200,6 @@ fn spawn_webserver( format!("{}-0", args.shard_uds_path), "--tokenizer-name".to_string(), args.model_id, - "--startup-time".to_string(), - startup_time.to_string(), ]; // Grammar support @@ -1278,6 +1276,8 @@ fn spawn_webserver( envs.push(("COMPUTE_TYPE".into(), compute_type.into())) } + envs.push(("DOWNLOAD_TIME".into(), download_time.to_string().into())); + let mut webserver = match Command::new("text-generation-router") .args(router_args) .envs(envs) diff --git a/router/src/main.rs b/router/src/main.rs index 65b2b747812..b060d73cf30 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -87,8 +87,6 @@ struct Args { disable_grammar_support: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, - #[clap(long, env)] - startup_time: u64, } #[derive(Debug, Subcommand)] @@ -131,7 +129,6 @@ async fn main() -> Result<(), RouterError> { disable_grammar_support, max_client_batch_size, command, - startup_time, } = args; let print_schema_command = match command { @@ -381,8 +378,6 @@ async fn main() -> Result<(), RouterError> { } }; - tracing::info!("start time of the model is {startup_time}"); - // Run server server::run( master_shard_uds_path, @@ -414,7 +409,6 @@ async fn main() -> Result<(), RouterError> { disable_grammar_support, max_client_batch_size, print_schema_command, - startup_time, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index 55ac3b16046..0f0d497118a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1434,7 +1434,6 @@ pub async fn run( grammar_support: bool, max_client_batch_size: usize, print_schema_command: bool, - start_time: u64, ) -> Result<(), WebServerError> { // OpenAPI documentation #[derive(OpenApi)] @@ -1514,6 +1513,7 @@ pub async fn run( ) )] struct ApiDoc; + let download_time = std::env::var("DOWNLOAD_TIME").unwrap_or("30".to_string()).parse::().unwrap_or(30); let length_time = Instant::now(); // Create state @@ -1895,11 +1895,11 @@ pub async fn run( .layer(cors_layer); tracing::info!("Connected"); - let total_time = length_time.elapsed() + Duration::from_secs(start_time); + let total_time = length_time.elapsed() + Duration::from_secs(download_time); tracing::info!("total time for router to boot up and connect to model server {:?}", length_time.elapsed()); tracing::info!("the total time in secs of boot time is {:?}", total_time); metrics::gauge!("tgi_model_load_time").set(total_time.as_secs_f64()); - + if ngrok {