From 1dbcb775fd89da53f9f0c032fb5366b87650a9e7 Mon Sep 17 00:00:00 2001
From: Moritz Althaus
Date: Tue, 10 Dec 2024 12:03:22 +0100
Subject: [PATCH] feat!: add option to ask for special tokens in completion response

---
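Reviewer notes, kept below the fold so `git am` drops them from the commit
message: a minimal usage sketch of the new opt-in. The endpoint URL, token,
and model name are placeholders; `Client`, `How`, `with_model`, and
`output_of` are the existing crate APIs also exercised by the integration
tests in this patch.

    use aleph_alpha_client::{Client, How, TaskCompletion};

    async fn special_tokens_demo() {
        // Placeholders: substitute a real inference endpoint and API token.
        let client = Client::with_auth("https://inference-api.example.com", "API_TOKEN").unwrap();
        // Opt into the raw completion, i.e. keep special tokens such as
        // <|eot_id|> or <|python_tag|> in the returned text.
        let task = TaskCompletion::from_text("An apple a day")
            .with_maximum_tokens(10)
            .with_special_tokens();
        let output = client
            .output_of(&task.with_model("llama-3.1-8b-instruct"), &How::default())
            .await
            .unwrap();
        // With special_tokens set, `completion` carries the un-optimized raw
        // completion produced by the model.
        println!("{}", output.completion);
    }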

 src/completion.rs    | 49 ++++++++++++++++++++++++++++++++++---------------
 src/lib.rs           |  1 +
 src/prompt.rs        |  1 +
 tests/integration.rs | 28 ++++++++++++++++++++++++++++
 4 files changed, 64 insertions(+), 15 deletions(-)

diff --git a/src/completion.rs b/src/completion.rs
index 68a340b..b5fd62c 100644
--- a/src/completion.rs
+++ b/src/completion.rs
@@ -11,6 +11,8 @@ pub struct TaskCompletion<'a> {
     pub stopping: Stopping<'a>,
     /// Sampling controls how the tokens ("words") are selected for the completion.
     pub sampling: Sampling<'a>,
+    /// Whether to include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
+    pub special_tokens: bool,
 }
 
 impl<'a> TaskCompletion<'a> {
@@ -20,6 +22,7 @@ impl<'a> TaskCompletion<'a> {
             prompt: Prompt::from_text(text),
             stopping: Stopping::NO_TOKEN_LIMIT,
             sampling: Sampling::MOST_LIKELY,
+            special_tokens: false,
         }
     }
 
@@ -32,6 +35,12 @@
         self.stopping.stop_sequences = stop_sequences;
         self
     }
+
+    /// Include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
+    pub fn with_special_tokens(mut self) -> Self {
+        self.special_tokens = true;
+        self
+    }
 }
 
 /// Sampling controls how the tokens ("words") are selected for the completion.
@@ -144,6 +153,13 @@ struct BodyCompletion<'a> {
     /// If true, the response will be streamed.
     #[serde(skip_serializing_if = "std::ops::Not::not")]
     pub stream: bool,
+    /// Forces the raw completion of the model to be returned.
+    /// For some models, the completion that was generated by the model may be optimized and
+    /// returned in the completion field of the CompletionResponse.
+    /// The raw completion, if returned, will contain the un-optimized completion.
+    /// Setting tokens to true or log_probs to any value will also trigger the raw completion to be returned.
+    #[serde(skip_serializing_if = "std::ops::Not::not")]
+    pub raw_completion: bool,
 }
 
 impl<'a> BodyCompletion<'a> {
@@ -158,6 +174,7 @@
             top_p: task.sampling.top_p,
             completion_bias_inclusion: task.sampling.complete_with_one_of,
             stream: false,
+            raw_completion: task.special_tokens,
         }
     }
     pub fn with_streaming(mut self) -> Self {
@@ -168,22 +185,15 @@
 
 #[derive(Deserialize, Debug, PartialEq, Eq)]
 pub struct ResponseCompletion {
-    pub model_version: String,
-    pub completions: Vec<CompletionOutput>,
+    model_version: String,
+    completions: Vec<DeserializedCompletion>,
 }
 
-impl ResponseCompletion {
-    /// The best completion in the answer.
-    pub fn best(&self) -> &CompletionOutput {
-        self.completions
-            .first()
-            .expect("Response is assumed to always have at least one completion")
-    }
-
-    /// Text of the best completion.
-    pub fn best_text(&self) -> &str {
-        &self.best().completion
-    }
+#[derive(Deserialize, Debug, PartialEq, Eq)]
+struct DeserializedCompletion {
+    completion: String,
+    finish_reason: String,
+    raw_completion: Option<String>,
 }
 
 /// Completion and metainformation returned by a completion task
@@ -209,7 +219,16 @@ impl Task for TaskCompletion<'_> {
     }
 
     fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
-        response.completions.pop().unwrap()
+        let deserialized = response.completions.pop().unwrap();
+        let completion = if self.special_tokens {
+            deserialized.raw_completion.unwrap()
+        } else {
+            deserialized.completion
+        };
+        CompletionOutput {
+            completion,
+            finish_reason: deserialized.finish_reason,
+        }
     }
 }
 
diff --git a/src/lib.rs b/src/lib.rs
index bbfbb21..70fc1ae 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -324,6 +324,7 @@ impl Client {
 ///     prompt: prompt.clone(),
 ///     stopping: Stopping::from_maximum_tokens(10),
 ///     sampling: Sampling::MOST_LIKELY,
+///     special_tokens: false,
 /// };
 /// let response = client.completion(&task, model, &How::default()).await?;
 ///
diff --git a/src/prompt.rs b/src/prompt.rs
index 21c1188..ea2128c 100644
--- a/src/prompt.rs
+++ b/src/prompt.rs
@@ -96,6 +96,7 @@ impl<'a> Modality<'a> {
 ///     ]),
 ///     stopping: Stopping::from_maximum_tokens(10),
 ///     sampling: Sampling::MOST_LIKELY,
+///     special_tokens: false,
 /// };
 /// // Execute
 /// let model = "luminous-base";
diff --git a/tests/integration.rs b/tests/integration.rs
index 782a926..c3e4a6c 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -61,6 +61,28 @@ async fn completion_with_luminous_base() {
     assert!(!response.completion.is_empty())
 }
 
+#[tokio::test]
+async fn raw_completion_includes_python_tag() {
+    // When
+    let task = TaskCompletion::from_text(
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
+    )
+    .with_maximum_tokens(30)
+    .with_special_tokens();
+
+    let model = "llama-3.1-8b-instruct";
+    let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
+    let response = client
+        .output_of(&task.with_model(model), &How::default())
+        .await
+        .unwrap();
+    assert!(response.completion.trim().starts_with("<|python_tag|>"));
+}
+
 #[tokio::test]
 async fn request_authentication_has_priority() {
     let bad_pharia_ai_token = "DUMMY";
@@ -201,6 +223,7 @@ async fn complete_structured_prompt() {
             stop_sequences: &stop_sequences[..],
         },
         sampling: Sampling::MOST_LIKELY,
+        special_tokens: false,
     };
     let model = "luminous-base";
     let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -230,6 +253,7 @@ async fn maximum_tokens_none_request() {
         prompt: Prompt::from_text(prompt),
         stopping,
         sampling: Sampling::MOST_LIKELY,
+        special_tokens: false,
     };
     let model = "luminous-base";
     let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -363,6 +387,7 @@ async fn describe_image_starting_from_a_path() {
         ]),
         stopping: Stopping::from_maximum_tokens(10),
         sampling: Sampling::MOST_LIKELY,
+        special_tokens: false,
     };
     let model = "luminous-base";
     let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -392,6 +417,7 @@ async fn describe_image_starting_from_a_dyn_image() {
         ]),
         stopping: Stopping::from_maximum_tokens(10),
         sampling: Sampling::MOST_LIKELY,
+        special_tokens: false,
     };
     let model = "luminous-base";
     let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -418,6 +444,7 @@ async fn only_answer_with_specific_animal() {
             complete_with_one_of: &[" dog"],
             ..Default::default()
         },
+        special_tokens: false,
     };
     let model = "luminous-base";
     let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
@@ -444,6 +471,7 @@ async fn answer_should_continue() {
             complete_with_one_of: &[" Says.", " Art.", " Weekend."],
             ..Default::default()
         },
+        special_tokens: false,
     };
     let model = "luminous-base";
     let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
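
For reviewers, a sketch of what the flag changes on the wire, grounded in the
serde derives above; the completion text and finish_reason value are invented,
and serde_json is assumed only to spell out the shapes:

    use serde_json::json;

    fn main() {
        // Request body: with_special_tokens() sets raw_completion to true;
        // when false it is omitted entirely (skip_serializing_if). All other
        // BodyCompletion fields are elided here.
        let request_fragment = json!({ "raw_completion": true });

        // One entry of "completions" in the response. body_to_output picks
        // raw_completion over completion when special_tokens is set.
        let completion_entry = json!({
            "completion": "def is_prime(n):",
            "finish_reason": "maximum_tokens",
            "raw_completion": "<|python_tag|>def is_prime(n):",
        });
        println!("{request_fragment}\n{completion_entry}");
    }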