diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d2ca38e5a0e..7a7909b683b 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1162,6 +1162,7 @@ fn spawn_webserver( max_input_tokens: usize, max_total_tokens: usize, max_batch_prefill_tokens: u32, + download_time: u64, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, ) -> Result { @@ -1275,6 +1276,8 @@ fn spawn_webserver( envs.push(("COMPUTE_TYPE".into(), compute_type.into())) } + envs.push(("DOWNLOAD_TIME".into(), download_time.to_string().into())); + let mut webserver = match Command::new("text-generation-router") .args(router_args) .envs(envs) @@ -1341,6 +1344,7 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R fn main() -> Result<(), LauncherError> { // Pattern match configuration let args: Args = Args::parse(); + let start_time = Instant::now(); // Filter events with LOG_LEVEL let varname = "LOG_LEVEL"; @@ -1622,12 +1626,14 @@ fn main() -> Result<(), LauncherError> { return Ok(()); } + let download_time = start_time.elapsed().as_secs(); let mut webserver = spawn_webserver( num_shard, args, max_input_tokens, max_total_tokens, max_batch_prefill_tokens, + download_time, shutdown.clone(), &shutdown_receiver, ) diff --git a/router/src/server.rs b/router/src/server.rs index d3a280ca3c8..1ea70ec8837 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -54,6 +54,7 @@ use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; +use tokio::time::Duration; /// Generate tokens if `stream == false` or a stream of token if `stream == true` #[utoipa::path( @@ -1512,6 +1513,8 @@ pub async fn run( ) )] struct ApiDoc; + let download_time = std::env::var("DOWNLOAD_TIME").unwrap_or("30".to_string()).parse::().unwrap_or(30); + let length_time = Instant::now(); // Create state if print_schema_command { @@ -1892,6 +1895,10 @@ pub async fn run( .layer(cors_layer); tracing::info!("Connected"); + let total_time = length_time.elapsed() + Duration::from_secs(download_time); + metrics::gauge!("tgi_model_load_time").set(total_time.as_secs_f64()); + + if ngrok { #[cfg(feature = "ngrok")]