fix: refactor post_processor logic and add test (#2137)
* fix: refactor post_processor logic and add test

* fix: remove dev comment

* fix: adjust when post_processor is overridden and improve create_post_processor
drbh authored Jun 27, 2024
1 parent 3ea8259 commit 74b0231
Showing 1 changed file with 117 additions and 29 deletions.
146 changes: 117 additions & 29 deletions router/src/main.rs
@@ -304,37 +304,21 @@ async fn main() -> Result<(), RouterError> {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
-let tokenizer: Option<Tokenizer> =
-    tokenizer_filename.and_then(|filename| {
-        let mut tokenizer = Tokenizer::from_file(filename).ok();
-        if let Some(tokenizer) = &mut tokenizer{
-            if let Some(class) = &tokenizer_config.tokenizer_class{
-                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
-                    tracing::info!("Overriding LllamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
-                    let mut single = vec![];
-                    let mut special_tokens = vec![];
-                    if let Some(true) = &tokenizer_config.add_bos_token{
-                        if let Some(bos_token) = &tokenizer_config.bos_token{
-                            let bos_token_id = tokenizer.token_to_id(&bos_token).expect("Should have found the bos token id");
-                            special_tokens.push((bos_token.clone(), bos_token_id));
-                            single.push(bos_token.to_string());
-                        }
-                    }
-                    single.push("$0".to_string());
-                    if let Some(true) = &tokenizer_config.add_eos_token{
-                        if let Some(eos_token) = &tokenizer_config.eos_token{
-                            let eos_token_id = tokenizer.token_to_id(&eos_token).expect("Should have found the eos token id");
-                            special_tokens.push((eos_token.clone(), eos_token_id));
-                            single.push(eos_token.to_string());
-                        }
-                    }
-                    let post_processor = TemplateProcessing::builder().try_single(single).unwrap().special_tokens(special_tokens).build().unwrap();
-                    tokenizer.with_post_processor(post_processor);
-                }
-            }
-        }
-        tokenizer
-    });
+let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
+    let mut tokenizer = Tokenizer::from_file(filename).ok();
+    if let Some(tokenizer) = &mut tokenizer {
+        if let Some(class) = &tokenizer_config.tokenizer_class {
+            if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() {
+                if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
+                    tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
+                    tokenizer.with_post_processor(post_processor);
+                }
+            }
+        }
+    }
+    tokenizer
+});
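
The new condition only installs the template override when the configured class is a Llama tokenizer and tokenizer.json does not already ship a post-processor. A minimal sketch of that guard, factored into a hypothetical helper that is not part of this commit, using only calls that appear in the diff:

    // Hypothetical helper (not in the commit): true when the Llama template override
    // above should be applied, i.e. the configured class is a Llama tokenizer and
    // tokenizer.json did not already provide a post_processor.
    fn needs_llama_override(tokenizer: &tokenizers::Tokenizer, class: &str) -> bool {
        (class == "LlamaTokenizer" || class == "LlamaTokenizerFast")
            && tokenizer.get_post_processor().is_none()
    }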

let preprocessor_config =
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
@@ -543,6 +527,77 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
Some(tokenizer_config)
}

/// Create a post_processor for the LlamaTokenizer
pub fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);

let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();

if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}

if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}

let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();

if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos)
.expect("Should have found the bos token id");
special_tokens.push((bos.clone(), bos_token_id));
single.push(format!("{}:0", bos));
pair.push(format!("{}:0", bos));
}
}

single.push("$A:0".to_string());
pair.push("$A:0".to_string());

if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos)
.expect("Should have found the eos token id");
special_tokens.push((eos.clone(), eos_token_id));
single.push(format!("{}:0", eos));
pair.push(format!("{}:0", eos));
}
}

if add_bos_token {
if let Some(bos) = bos_token {
single.push(format!("{}:1", bos));

sywangyi (Contributor) commented on Jun 28, 2024:

should be pair.push() instead of single.push() here. And @Narsil, I hit a similar problem when enabling microsoft/Phi-3-mini-4k-instruct: it crashes with an out-of-range index in batch.slots[batch.slot_indices]. I dug into the issue and root-caused it to the Rust tokenizer returning a different tokenized input than batch_tokenized_inputs in flash_causal_lm.py (the bos token is missing), so the slot allocation mismatches because the total token count differs between the Python and Rust layers. Do you have an idea how to fix this? The Rust tokenizer post-processing seems to have a bug.

sywangyi (Contributor) commented on Jul 1, 2024:

I filed a PR to fix the Phi-3 issue; please help check. #2148

}
}

pair.push("$B:1".to_string());

if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos));
}
}

let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;

Ok(post_processor)
}
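
For the default Llama settings (add_bos_token defaulting to true, add_eos_token to false), the function above produces single = "<s>:0 $A:0 <s>:1" and pair = "<s>:0 $A:0 $B:1", matching the expected values in the test below. As the inline review comment notes, the trailing "<s>:1" arguably belongs in the pair template instead. A hedged sketch of the templates the suggested fix would build, following the transformers override linked in the tracing message; the builder calls mirror the test below, and the id 1 for "<s>" is taken from there:

    // Sketch only, not the committed behavior: the templates a fixed version would build
    // after moving the second bos push from `single` to `pair`
    // (add_bos_token = true, add_eos_token = false).
    fn corrected_llama_post_processor() -> TemplateProcessing {
        TemplateProcessing::builder()
            .try_single("<s>:0 $A:0")          // bos + first sequence (type id 0)
            .unwrap()
            .try_pair("<s>:0 $A:0 <s>:1 $B:1") // bos + first sequence, then bos + second sequence (type id 1)
            .unwrap()
            .special_tokens(vec![("<s>".to_string(), 1)]) // assumes "<s>" maps to id 1, as in the test below
            .build()
            .unwrap()
    }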

#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
@@ -552,3 +607,36 @@ enum RouterError {
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_create_post_processor() {
let tokenizer_config = HubTokenizerConfig {
add_bos_token: None,
add_eos_token: None,
bos_token: Some("<s>".to_string()),
eos_token: Some("</s>".to_string()),
chat_template: None,
tokenizer_class: None,
completion_template: None,
};

let tokenizer =
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();

let expected = TemplateProcessing::builder()
.try_single("<s>:0 $A:0 <s>:1")
.unwrap()
.try_pair("<s>:0 $A:0 $B:1")
.unwrap()
.special_tokens(vec![("<s>".to_string(), 1)])
.build()
.unwrap();

assert_eq!(post_processor, expected);
}
}
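
To make the "$A"/"$B" template syntax concrete, here is a hedged usage sketch, not part of the commit, that applies the post-processor produced by create_post_processor and checks the added special tokens. It assumes the same TinyLlama checkpoint as the test above (so from_pretrained needs network access), the id 1 for "<s>" from that test, and the imports already present in this file:

    // Sketch only: apply create_post_processor's output and observe its effect on encoding.
    fn demo_post_processor() {
        let mut tokenizer =
            Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
        let tokenizer_config = HubTokenizerConfig {
            add_bos_token: None, // defaults to true inside create_post_processor
            add_eos_token: None, // defaults to false
            bos_token: Some("<s>".to_string()),
            eos_token: Some("</s>".to_string()),
            chat_template: None,
            tokenizer_class: None,
            completion_template: None,
        };
        let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
        tokenizer.with_post_processor(post_processor);

        // "$A:0" stands for the first input sequence with type id 0; "<s>:0" prepends the bos token.
        let encoding = tokenizer.encode("Hello world", true).unwrap();
        assert_eq!(encoding.get_ids()[0], 1); // first token is "<s>"
        // The trailing "<s>" comes from the single.push the reviewer flags above.
        assert_eq!(*encoding.get_ids().last().unwrap(), 1);
    }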
