From f85cd58e2cf299760786a3cd00d527fe141274de Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 27 Jun 2024 18:34:26 +0000 Subject: [PATCH 1/3] fix: refactor post_processor logic and add test --- router/src/main.rs | 144 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 115 insertions(+), 29 deletions(-) diff --git a/router/src/main.rs b/router/src/main.rs index 3aa5a6bf9d2..bf468ce735a 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -304,38 +304,23 @@ async fn main() -> Result<(), RouterError> { tracing::warn!("Could not find tokenizer config locally and no API specified"); HubTokenizerConfig::default() }); - let tokenizer: Option = - tokenizer_filename.and_then(|filename| { - let mut tokenizer = Tokenizer::from_file(filename).ok(); - if let Some(tokenizer) = &mut tokenizer{ - if let Some(class) = &tokenizer_config.tokenizer_class{ - if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" { - tracing::info!("Overriding LllamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); - let mut single = vec![]; - let mut special_tokens = vec![]; - if let Some(true) = &tokenizer_config.add_bos_token{ - if let Some(bos_token) = &tokenizer_config.bos_token{ - let bos_token_id = tokenizer.token_to_id(&bos_token).expect("Should have found the bos token id"); - special_tokens.push((bos_token.clone(), bos_token_id)); - single.push(bos_token.to_string()); - } - } - single.push("$0".to_string()); - if let Some(true) = &tokenizer_config.add_eos_token{ - if let Some(eos_token) = &tokenizer_config.eos_token{ - let eos_token_id = tokenizer.token_to_id(&eos_token).expect("Should have found the eos token id"); - special_tokens.push((eos_token.clone(), eos_token_id)); - single.push(eos_token.to_string()); - } - } - let post_processor = 
TemplateProcessing::builder().try_single(single).unwrap().special_tokens(special_tokens).build().unwrap(); + + let tokenizer: Option = tokenizer_filename.and_then(|filename| { + let mut tokenizer = Tokenizer::from_file(filename).ok(); + if let Some(tokenizer) = &mut tokenizer { + if let Some(class) = &tokenizer_config.tokenizer_class { + if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" { + tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); + if let Some(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { tokenizer.with_post_processor(post_processor); - }} + } } - tokenizer - - }); + } + } + tokenizer + }); + // tokenizer.with_post_processor(post_processor); let preprocessor_config = preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); let processor_config = processor_config_filename @@ -543,6 +528,74 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option Option { + let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true); + let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false); + + let bos_token = tokenizer_config.bos_token.as_ref(); + let eos_token = tokenizer_config.eos_token.as_ref(); + + if add_bos_token && bos_token.is_none() { + panic!("add_bos_token = true but bos_token is None"); + } + + if add_eos_token && eos_token.is_none() { + panic!("add_eos_token = true but eos_token is None"); + } + + let mut single = String::new(); + let mut pair = String::new(); + let mut special_tokens = Vec::new(); + + if add_bos_token { + let bos = bos_token.unwrap(); + let bos_token_id = tokenizer + .token_to_id(bos) + .expect("Should have found the bos token id"); + special_tokens.push((bos.clone(), bos_token_id)); + single.push_str(&format!("{}:0 ", bos)); + 
pair.push_str(&format!("{}:0 ", bos));
+    }
+
+    single.push_str("$A:0");
+    pair.push_str("$A:0");
+
+    if add_eos_token {
+        let eos = eos_token.unwrap();
+        let eos_token_id = tokenizer
+            .token_to_id(eos)
+            .expect("Should have found the eos token id");
+        special_tokens.push((eos.clone(), eos_token_id));
+        single.push_str(&format!(" {}:0", eos));
+        pair.push_str(&format!(" {}:0", eos));
+    }
+
+    if add_bos_token {
+        pair.push_str(&format!(" {}:1", bos_token.unwrap()));
+    }
+
+    pair.push_str(" $B:1");
+
+    if add_eos_token {
+        pair.push_str(&format!(" {}:1", eos_token.unwrap()));
+    }
+
+    let post_processor = TemplateProcessing::builder()
+        .try_single(single)
+        .unwrap()
+        .try_pair(pair)
+        .unwrap()
+        .special_tokens(special_tokens)
+        .build()
+        .unwrap();
+
+    Some(post_processor)
+}
+
 #[derive(Debug, Error)]
 enum RouterError {
     #[error("Argument validation error: {0}")]
@@ -552,3 +605,36 @@ enum RouterError {
     #[error("Tokio runtime failed to start: {0}")]
     Tokio(#[from] std::io::Error),
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_create_post_processor() {
+        let tokenizer_config = HubTokenizerConfig {
+            add_bos_token: None,
+            add_eos_token: None,
+            bos_token: Some("<s>".to_string()),
+            eos_token: Some("</s>".to_string()),
+            chat_template: None,
+            tokenizer_class: None,
+            completion_template: None,
+        };
+
+        let tokenizer =
+            Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
+        let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
+
+        let expected = TemplateProcessing::builder()
+            .try_single("<s>:0 $A:0")
+            .unwrap()
+            .try_pair("<s>:0 $A:0 <s>:1 $B:1")
+            .unwrap()
+            .special_tokens(vec![("<s>".to_string(), 1)])
+            .build()
+            .unwrap();
+
+        assert_eq!(post_processor, expected);
+    }
+}

From 74535ce80f0ad680d43155cf691077cd4f665dc2 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 27 Jun 2024 18:42:41 +0000
Subject: [PATCH 2/3] fix: remove dev comment

---
 router/src/main.rs | 1 -
 1 file changed, 1 deletion(-)
diff --git a/router/src/main.rs b/router/src/main.rs index bf468ce735a..7ec1640eed5 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -320,7 +320,6 @@ async fn main() -> Result<(), RouterError> { tokenizer }); - // tokenizer.with_post_processor(post_processor); let preprocessor_config = preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); let processor_config = processor_config_filename From a921854d92a8cfe15ca48974435fa34ea00c8907 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 27 Jun 2024 20:47:06 +0000 Subject: [PATCH 3/3] fix: adjust when post_processor is overridden and improve create_post_processor --- router/src/main.rs | 71 ++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/router/src/main.rs b/router/src/main.rs index 7ec1640eed5..1e8093d881a 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -309,9 +309,9 @@ async fn main() -> Result<(), RouterError> { let mut tokenizer = Tokenizer::from_file(filename).ok(); if let Some(tokenizer) = &mut tokenizer { if let Some(class) = &tokenizer_config.tokenizer_class { - if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" { - tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); - if let Some(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { + if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() { + if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { + tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in 
https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); tokenizer.with_post_processor(post_processor); } } @@ -531,7 +531,7 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option Option { +) -> Result { let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true); let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false); @@ -546,53 +546,56 @@ pub fn create_post_processor( panic!("add_eos_token = true but eos_token is None"); } - let mut single = String::new(); - let mut pair = String::new(); + let mut single = Vec::new(); + let mut pair = Vec::new(); let mut special_tokens = Vec::new(); if add_bos_token { - let bos = bos_token.unwrap(); - let bos_token_id = tokenizer - .token_to_id(bos) - .expect("Should have found the bos token id"); - special_tokens.push((bos.clone(), bos_token_id)); - single.push_str(&format!("{}:0 ", bos)); - pair.push_str(&format!("{}:0 ", bos)); + if let Some(bos) = bos_token { + let bos_token_id = tokenizer + .token_to_id(bos) + .expect("Should have found the bos token id"); + special_tokens.push((bos.clone(), bos_token_id)); + single.push(format!("{}:0", bos)); + pair.push(format!("{}:0", bos)); + } } - single.push_str("$A:0"); - pair.push_str("$A:0"); + single.push("$A:0".to_string()); + pair.push("$A:0".to_string()); if add_eos_token { - let eos = eos_token.unwrap(); - let eos_token_id = tokenizer - .token_to_id(eos) - .expect("Should have found the eos token id"); - special_tokens.push((eos.clone(), eos_token_id)); - single.push_str(&format!(" {}:0", eos)); - pair.push_str(&format!(" {}:0", eos)); + if let Some(eos) = eos_token { + let eos_token_id = tokenizer + .token_to_id(eos) + .expect("Should have found the eos token id"); + special_tokens.push((eos.clone(), eos_token_id)); + single.push(format!("{}:0", eos)); + pair.push(format!("{}:0", eos)); + } } if add_bos_token { - pair.push_str(&format!(" {}:1", 
bos_token.unwrap()));
+        if let Some(bos) = bos_token {
+            single.push(format!("{}:1", bos));
+        }
     }
 
-    pair.push_str(" $B:1");
+    pair.push("$B:1".to_string());
 
     if add_eos_token {
-        pair.push_str(&format!(" {}:1", eos_token.unwrap()));
+        if let Some(eos) = eos_token {
+            pair.push(format!("{}:1", eos));
+        }
     }
 
     let post_processor = TemplateProcessing::builder()
-        .try_single(single)
-        .unwrap()
-        .try_pair(pair)
-        .unwrap()
+        .try_single(single)?
+        .try_pair(pair)?
         .special_tokens(special_tokens)
-        .build()
-        .unwrap();
+        .build()?;
 
-    Some(post_processor)
+    Ok(post_processor)
 }
 
 #[derive(Debug, Error)]
@@ -626,9 +629,9 @@ mod tests {
         let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
 
         let expected = TemplateProcessing::builder()
-            .try_single("<s>:0 $A:0")
+            .try_single("<s>:0 $A:0 <s>:1")
             .unwrap()
-            .try_pair("<s>:0 $A:0 <s>:1 $B:1")
+            .try_pair("<s>:0 $A:0 $B:1")
             .unwrap()
             .special_tokens(vec![("<s>".to_string(), 1)])
             .build()