diff --git a/router/src/infer.rs b/router/src/infer.rs
index eaa72a75403..472b7d66d5f 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1,8 +1,8 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, CompletionTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse,
-    HubTokenizerConfig, Message, PrefillToken, Queue, Token,
+    ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
+    Message, PrefillToken, Queue, Token,
 };
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
@@ -33,8 +33,6 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Chat template
     chat_template: Option<ChatTemplate>,
-    /// Completion template
-    completion_template: Option<CompletionTemplate>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
 }
@@ -90,10 +88,6 @@ impl Infer {
             .chat_template
             .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
 
-        let completion_template = tokenizer_config
-            .completion_template
-            .map(CompletionTemplate::new);
-
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
@@ -102,7 +96,6 @@ impl Infer {
             queue,
             shared,
             chat_template,
-            completion_template,
             limit_concurrent_requests: semaphore,
         }
     }
@@ -193,24 +186,6 @@ impl Infer {
             })
     }
 
-    /// Apply the completion template to the request
-    #[instrument(skip_all)]
-    pub(crate) fn apply_completion_template(
-        &self,
-        prompt: String,
-        suffix: Option<String>,
-    ) -> Result<String, InferError> {
-        self.completion_template
-            .as_ref()
-            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(prompt, suffix)
-            .map_err(|e| {
-                metrics::increment_counter!("tgi_request_failure", "err" => "template");
-                tracing::error!("{e}");
-                e
-            })
-    }
-
     /// Add a new request to the queue and return a InferResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate(
@@ -367,34 +342,6 @@ impl ChatTemplate {
     }
 }
 
-#[derive(Clone)]
-struct CompletionTemplate {
-    template: Template<'static, 'static>,
-}
-
-impl CompletionTemplate {
-    fn new(template: String) -> Self {
-        let mut env = Box::new(Environment::new());
-        let template_str = template.into_boxed_str();
-        env.add_function("raise_exception", raise_exception);
-        // leaking env and template_str as read-only, static resources for performance.
-        let template = Box::leak(env)
-            .template_from_str(Box::leak(template_str))
-            .unwrap();
-
-        Self { template }
-    }
-
-    fn apply(&self, prompt: String, suffix: Option<String>) -> Result<String, InferError> {
-        self.template
-            .render(CompletionTemplateInputs {
-                prompt: prompt.as_str(),
-                suffix: suffix.as_deref(),
-            })
-            .map_err(InferError::TemplateError)
-    }
-}
-
 /// Batching logic
 /// Will be launched in a background Tokio task
 ///
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 70619c60b9b..694c3b66b7a 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -618,12 +618,6 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }
 
-#[derive(Clone, Serialize, Deserialize)]
-pub(crate) struct CompletionTemplateInputs<'a> {
-    prompt: &'a str,
-    suffix: Option<&'a str>,
-}
-
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
     #[schema(example = "user")]
diff --git a/router/src/server.rs b/router/src/server.rs
index 020a976a062..78f3e2a97b0 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -579,24 +579,22 @@ async fn completions(
     let max_new_tokens = req.max_tokens.or(Some(100));
     let seed = req.seed;
 
-    let inputs = match infer.apply_completion_template(req.prompt, req.suffix) {
-        Ok(inputs) => inputs,
-        Err(err) => {
-            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
-            tracing::error!("{err}");
-            return Err((
-                StatusCode::UNPROCESSABLE_ENTITY,
-                Json(ErrorResponse {
-                    error: err.to_string(),
-                    error_type: err.error_type().to_string(),
-                }),
-            ));
-        }
-    };
+    // if suffix is present throw an error
+    if req.suffix.is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Suffix is not supported and can be achieved by preprocessing the prompt."
+                    .to_string(),
+                error_type: "suffix not supported".to_string(),
+            }),
+        ));
+    }
 
     // build the request passing some parameters
     let generate_request = GenerateRequest {
-        inputs: inputs.to_string(),
+        inputs: req.prompt.to_string(),
        parameters: GenerateParameters {
             best_of: None,
             temperature: req.temperature,
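Note: the new 422 error points clients at a workaround, folding any suffix into the prompt before calling the completions endpoint. Below is a minimal, hypothetical sketch of that client-side preprocessing; the `fill_in_middle` helper and the <PRE>/<SUF>/<MID> infill markers are illustrative assumptions (they match Code Llama's documented fill-in-the-middle format, but other models use different sentinels) and are not part of this patch.

    /// Hypothetical client-side replacement for the removed completion
    /// template: merge `prompt` and `suffix` into one fill-in-the-middle
    /// string, to be sent as the `prompt` of a completions request.
    fn fill_in_middle(prompt: &str, suffix: Option<&str>) -> String {
        match suffix {
            // Assumed Code Llama-style infill markers; substitute your model's own.
            Some(suffix) => format!("<PRE> {prompt} <SUF>{suffix} <MID>"),
            None => prompt.to_string(),
        }
    }

    fn main() {
        let inputs = fill_in_middle("fn add(a: i32, b: i32) -> i32 {", Some("}"));
        // Prints: <PRE> fn add(a: i32, b: i32) -> i32 { <SUF>} <MID>
        println!("{inputs}");
    }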