Skip to content
This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

Commit

Permalink
fixed "" in url field
Browse files Browse the repository at this point in the history
  • Loading branch information
valaises committed Oct 18, 2023
1 parent efa8c84 commit 7f2c854
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 41 deletions.
22 changes: 7 additions & 15 deletions src/forward_to_hf_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,12 @@ use crate::call_validation::SamplingParameters;


pub async fn forward_to_hf_style_endpoint(
save_url: &mut String,
save_url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<serde_json::Value, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
Expand All @@ -34,34 +30,30 @@ pub async fn forward_to_hf_style_endpoint(
"inputs": prompt,
"parameters": params_json,
});
let req = client.post(&url)
let req = client.post(save_url)
.headers(headers)
.body(data.to_string())
.send()
.await;
let resp = req.map_err(|e| format!("{}", e))?;
let status_code = resp.status().as_u16();
let response_txt = resp.text().await.map_err(|e|
format!("reading from socket {}: {}", url, e)
format!("reading from socket {}: {}", save_url, e)
)?;
if status_code != 200 {
return Err(format!("{} status={} text {}", url, status_code, response_txt));
return Err(format!("{} status={} text {}", save_url, status_code, response_txt));
}
Ok(serde_json::from_str(&response_txt).unwrap())
}


pub async fn forward_to_hf_style_endpoint_streaming(
save_url: &mut String,
save_url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<EventSource, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
Expand All @@ -77,11 +69,11 @@ pub async fn forward_to_hf_style_endpoint_streaming(
"stream": true,
});

let builder = client.post(&url)
let builder = client.post(save_url)
.headers(headers)
.body(data.to_string());
let event_source: EventSource = EventSource::new(builder).map_err(|e|
format!("can't stream from {}: {}", url, e)
format!("can't stream from {}: {}", save_url, e)
)?;
Ok(event_source)
}
20 changes: 7 additions & 13 deletions src/forward_to_openai_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@ use crate::call_validation::SamplingParameters;


pub async fn forward_to_openai_style_endpoint(
mut save_url: &String,
save_url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<serde_json::Value, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
Expand All @@ -31,34 +28,31 @@ pub async fn forward_to_openai_style_endpoint(
"temperature": sampling_parameters.temperature,
"max_tokens": sampling_parameters.max_new_tokens,
});
let req = client.post(&url)
let req = client.post(save_url)
.headers(headers)
.body(data.to_string())
.send()
.await;
let resp = req.map_err(|e| format!("{}", e))?;
let status_code = resp.status().as_u16();
let response_txt = resp.text().await.map_err(|e|
format!("reading from socket {}: {}", url, e)
format!("reading from socket {}: {}", save_url, e)
)?;
// info!("forward_to_openai_style_endpoint: {} {}\n{}", url, status_code, response_txt);
if status_code != 200 {
return Err(format!("{} status={} text {}", url, status_code, response_txt));
return Err(format!("{} status={} text {}", save_url, status_code, response_txt));
}
Ok(serde_json::from_str(&response_txt).unwrap())
}

pub async fn forward_to_openai_style_endpoint_streaming(
mut save_url: &String,
save_url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<EventSource, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
Expand All @@ -72,11 +66,11 @@ pub async fn forward_to_openai_style_endpoint_streaming(
"temperature": sampling_parameters.temperature,
"max_tokens": sampling_parameters.max_new_tokens,
});
let builder = client.post(&url)
let builder = client.post(save_url)
.headers(headers)
.body(data.to_string());
let event_source: EventSource = EventSource::new(builder).map_err(|e|
format!("can't stream from {}: {}", url, e)
format!("can't stream from {}: {}", save_url, e)
)?;
Ok(event_source)
}
21 changes: 8 additions & 13 deletions src/restream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,23 @@ pub async fn scratchpad_interaction_not_stream(
let caps_locked = caps.read().unwrap();
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone())
};
let mut save_url: String = String::new();
let save_url: String = endpoint_template.replace("$MODEL", &model_name).clone();

let model_says = if endpoint_style == "hf" {
forward_to_hf_endpoint::forward_to_hf_style_endpoint(
&mut save_url,
&save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
} else {
forward_to_openai_endpoint::forward_to_openai_style_endpoint(
&mut save_url,
&save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
}.map_err(|e| {
Expand Down Expand Up @@ -141,30 +139,27 @@ pub async fn scratchpad_interaction_stream(
let caps_locked = caps.read().unwrap();
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone())
};
let mut save_url: String = String::new();
let save_url: String = endpoint_template.replace("$MODEL", &model_name).clone();
let mut event_source = if endpoint_style == "hf" {
forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming(
&mut save_url,
&save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
} else {
forward_to_openai_endpoint::forward_to_openai_style_endpoint_streaming(
&mut save_url,
&save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
}.map_err(|e| {
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
endpoint_template.clone().replace("$MODEL", &model_name),
scope.clone(),
false,
e.to_string(),
Expand Down

0 comments on commit 7f2c854

Please sign in to comment.