Skip to content
This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

Commit

Permalink
stop_toks + stop_length
Browse files Browse the repository at this point in the history
  • Loading branch information
olegklimov committed Oct 10, 2023
1 parent 8eb2cd5 commit e70a3c5
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 17 deletions.
10 changes: 6 additions & 4 deletions src/restream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,19 @@ pub async fn scratchpad_interaction_stream(
if let Some(token) = json.get("token") { // hf style produces this
let text = token.get("text").unwrap().as_str().unwrap().to_string();
let mut value: serde_json::Value;
(value, finished) = scratch.response_streaming(text, false).unwrap();
(value, finished) = scratch.response_streaming(text, false, false).unwrap();
value["created"] = json!(t1.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as f64 / 1000.0);
value["model"] = json!(model_name.clone());
value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap());
was_correct_output_even_if_error |= json.get("generated_text").is_some();
} else if let Some(choices) = json.get("choices") { // openai style
let choice0 = &choices[0];
let text = choice0.get("text").unwrap().as_str().unwrap().to_string();
let stopped = choice0.get("finish_reason").unwrap_or(&json!("")).as_str().unwrap().to_string().starts_with("stop");
let finish_reason = choice0.get("finish_reason").unwrap_or(&json!("")).as_str().unwrap().to_string();
let stop_toks = !finish_reason.is_empty() && finish_reason.starts_with("stop");
let stop_length = !finish_reason.is_empty() && !finish_reason.starts_with("stop");
let mut value: serde_json::Value;
(value, finished) = scratch.response_streaming(text, stopped).unwrap();
(value, finished) = scratch.response_streaming(text, stop_toks, stop_length).unwrap();
value["created"] = json!(t1.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as f64 / 1000.0);
model_name = json["model"].as_str().unwrap().to_string();
value["model"] = json!(model_name.clone());
Expand Down Expand Up @@ -240,7 +242,7 @@ pub async fn scratchpad_interaction_stream(
return;
} else if !finished {
let mut value: serde_json::Value;
(value, _) = scratch.response_streaming("".to_string(), true).unwrap();
(value, _) = scratch.response_streaming("".to_string(), false, true).unwrap();
value["created"] = json!(t1.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as f64 / 1000.0);
value["model"] = json!(model_name.clone());
let value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap());
Expand Down
3 changes: 2 additions & 1 deletion src/scratchpad_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ pub trait ScratchpadAbstract: Send {
fn response_streaming( // Only 1 choice, but streaming. Returns delta the user should see, and finished flag
&mut self,
delta: String, // if delta is empty, there is no more input, add final fields if needed
stopped: bool,
stop_toks: bool,
stop_length: bool,
) -> Result<(serde_json::Value, bool), String>;
}

Expand Down
5 changes: 3 additions & 2 deletions src/scratchpads/chat_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,10 @@ impl ScratchpadAbstract for GenericChatScratchpad {
fn response_streaming(
&mut self,
delta: String,
stopped: bool,
stop_toks: bool,
stop_length: bool,
) -> Result<(serde_json::Value, bool), String> {
self.dd.response_streaming(delta, stopped)
self.dd.response_streaming(delta, stop_toks)
}
}

5 changes: 3 additions & 2 deletions src/scratchpads/chat_llama2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ impl ScratchpadAbstract for ChatLlama2 {
fn response_streaming(
&mut self,
delta: String,
stopped: bool,
stop_toks: bool,
stop_length: bool,
) -> Result<(serde_json::Value, bool), String> {
self.dd.response_streaming(delta, stopped)
self.dd.response_streaming(delta, stop_toks)
}
}

19 changes: 11 additions & 8 deletions src/scratchpads/completion_single_file_fim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,20 +195,22 @@ impl ScratchpadAbstract for SingleFileFIM {
fn response_streaming(
&mut self,
delta: String,
stopped: bool,
stop_toks: bool,
stop_length: bool,
) -> Result<(serde_json::Value, bool), String> {
let mut finished = false;
let mut finished;
let json_choices;
if !delta.is_empty() {
// info!("XXXXX delta: {:?}", delta);
// info!("XXXXX stop_toks: {:?}", stop_toks);
// info!("XXXXX stop_length: {:?}", stop_length);
if !delta.is_empty() || stop_toks {
let mut s: String;
(s, finished) = cut_result(&delta, self.t.eot.as_str(), self.post.inputs.multiline);
if finished {
self.data4cache.completion0_finish_reason = "stop".to_string();
}
finished |= stopped;
finished |= stop_toks;
if finished {
// can stay consistent with trim() only if that's the final iteration
s = s.trim_end().to_string();
self.data4cache.completion0_finish_reason = if finished { "stop".to_string() } else { "".to_string() };
}
self.data4cache.completion0_text.push_str(&s);
json_choices = serde_json::json!([{
Expand All @@ -217,13 +219,14 @@ impl ScratchpadAbstract for SingleFileFIM {
"finish_reason": if finished { serde_json::Value::String("stop".to_string()) } else { serde_json::Value::Null },
}]);
} else {
assert!(stopped);
assert!(stop_length);
json_choices = serde_json::json!([{
"index": 0,
"code_completion": "",
"finish_reason": "length"
}]);
self.data4cache.completion0_finish_reason = "length".to_string();
finished = true;
}
telemetry_snippets::snippet_register_from_data4cache(&self.data4snippet, &mut self.data4cache);
let ans = serde_json::json!({
Expand Down

0 comments on commit e70a3c5

Please sign in to comment.