diff --git a/src/forward_to_hf_endpoint.rs b/src/forward_to_hf_endpoint.rs index e070646..642050c 100644 --- a/src/forward_to_hf_endpoint.rs +++ b/src/forward_to_hf_endpoint.rs @@ -11,16 +11,12 @@ use crate::call_validation::SamplingParameters; pub async fn forward_to_hf_style_endpoint( - save_url: &mut String, + save_url: &String, bearer: String, - model_name: &str, prompt: &str, client: &reqwest::Client, - endpoint_template: &String, sampling_parameters: &SamplingParameters, ) -> Result { - let url = endpoint_template.replace("$MODEL", model_name); - save_url.clone_from(&&url); let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); if !bearer.is_empty() { @@ -34,7 +30,7 @@ pub async fn forward_to_hf_style_endpoint( "inputs": prompt, "parameters": params_json, }); - let req = client.post(&url) + let req = client.post(save_url) .headers(headers) .body(data.to_string()) .send() @@ -42,26 +38,22 @@ pub async fn forward_to_hf_style_endpoint( let resp = req.map_err(|e| format!("{}", e))?; let status_code = resp.status().as_u16(); let response_txt = resp.text().await.map_err(|e| - format!("reading from socket {}: {}", url, e) + format!("reading from socket {}: {}", save_url, e) )?; if status_code != 200 { - return Err(format!("{} status={} text {}", url, status_code, response_txt)); + return Err(format!("{} status={} text {}", save_url, status_code, response_txt)); } Ok(serde_json::from_str(&response_txt).unwrap()) } pub async fn forward_to_hf_style_endpoint_streaming( - save_url: &mut String, + save_url: &String, bearer: String, - model_name: &str, prompt: &str, client: &reqwest::Client, - endpoint_template: &String, sampling_parameters: &SamplingParameters, ) -> Result { - let url = endpoint_template.replace("$MODEL", model_name); - save_url.clone_from(&&url); let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); if !bearer.is_empty() { @@ -77,11 +69,11 @@ pub async fn forward_to_hf_style_endpoint_streaming( "stream": true, }); - let builder = client.post(&url) + let builder = client.post(save_url) .headers(headers) .body(data.to_string()); let event_source: EventSource = EventSource::new(builder).map_err(|e| - format!("can't stream from {}: {}", url, e) + format!("can't stream from {}: {}", save_url, e) )?; Ok(event_source) } diff --git a/src/forward_to_openai_endpoint.rs b/src/forward_to_openai_endpoint.rs index bdca278..9a17d86 100644 --- a/src/forward_to_openai_endpoint.rs +++ b/src/forward_to_openai_endpoint.rs @@ -8,16 +8,13 @@ use crate::call_validation::SamplingParameters; pub async fn forward_to_openai_style_endpoint( - mut save_url: &String, + save_url: &String, bearer: String, model_name: &str, prompt: &str, client: &reqwest::Client, - endpoint_template: &String, sampling_parameters: &SamplingParameters, ) -> Result { - let url = endpoint_template.replace("$MODEL", model_name); - save_url.clone_from(&&url); let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); if !bearer.is_empty() { @@ -31,7 +28,7 @@ pub async fn forward_to_openai_style_endpoint( "temperature": sampling_parameters.temperature, "max_tokens": sampling_parameters.max_new_tokens, }); - let req = client.post(&url) + let req = client.post(save_url) .headers(headers) .body(data.to_string()) .send() @@ -39,26 +36,23 @@ pub async fn forward_to_openai_style_endpoint( let resp = req.map_err(|e| format!("{}", e))?; let status_code = resp.status().as_u16(); let response_txt = resp.text().await.map_err(|e| - format!("reading from socket {}: {}", url, e) + format!("reading from socket {}: {}", save_url, e) )?; // info!("forward_to_openai_style_endpoint: {} {}\n{}", url, status_code, response_txt); if status_code != 200 { - return Err(format!("{} status={} text {}", url, status_code, response_txt)); + return Err(format!("{} status={} text {}", save_url, status_code, response_txt)); } Ok(serde_json::from_str(&response_txt).unwrap()) } pub async fn forward_to_openai_style_endpoint_streaming( - mut save_url: &String, + save_url: &String, bearer: String, model_name: &str, prompt: &str, client: &reqwest::Client, - endpoint_template: &String, sampling_parameters: &SamplingParameters, ) -> Result { - let url = endpoint_template.replace("$MODEL", model_name); - save_url.clone_from(&&url); let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); if !bearer.is_empty() { @@ -72,11 +66,11 @@ pub async fn forward_to_openai_style_endpoint_streaming( "temperature": sampling_parameters.temperature, "max_tokens": sampling_parameters.max_new_tokens, }); - let builder = client.post(&url) + let builder = client.post(save_url) .headers(headers) .body(data.to_string()); let event_source: EventSource = EventSource::new(builder).map_err(|e| - format!("can't stream from {}: {}", url, e) + format!("can't stream from {}: {}", save_url, e) )?; Ok(event_source) } diff --git a/src/restream.rs b/src/restream.rs index b88a17f..480c75a 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -33,25 +33,23 @@ pub async fn scratchpad_interaction_not_stream( let caps_locked = caps.read().unwrap(); (caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone()) }; - let mut save_url: String = String::new(); + let save_url: String = endpoint_template.replace("$MODEL", &model_name).clone(); + let model_says = if endpoint_style == "hf" { forward_to_hf_endpoint::forward_to_hf_style_endpoint( - &mut save_url, + &save_url, bearer.clone(), - &model_name, &prompt, &client, - &endpoint_template, ¶meters, ).await } else { forward_to_openai_endpoint::forward_to_openai_style_endpoint( - &mut save_url, + &save_url, bearer.clone(), &model_name, &prompt, &client, - &endpoint_template, ¶meters, ).await }.map_err(|e| { @@ -141,30 +139,27 @@ pub async fn scratchpad_interaction_stream( let caps_locked = caps.read().unwrap(); (caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone()) }; - let mut save_url: String = String::new(); + let save_url: String = endpoint_template.replace("$MODEL", &model_name).clone(); let mut event_source = if endpoint_style == "hf" { forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming( - &mut save_url, + &save_url, bearer.clone(), - &model_name, &prompt, &client, - &endpoint_template, ¶meters, ).await } else { forward_to_openai_endpoint::forward_to_openai_style_endpoint_streaming( - &mut save_url, + &save_url, bearer.clone(), &model_name, &prompt, &client, - &endpoint_template, ¶meters, ).await }.map_err(|e| { tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new( - save_url.clone(), + endpoint_template.clone().replace("$MODEL", &model_name), scope.clone(), false, e.to_string(),