Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Aug 6, 2024
1 parent fe41687 commit bfb6222
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tokenizers/src/normalizers/byte_level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use crate::processors::byte_level::bytes_char;
use crate::tokenizer::{NormalizedString, Normalizer, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use crate::utils::macro_rules_attribute;

#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
pub struct ByteLevel {}
#[derive(Clone, Debug)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct ByteLevel;

lazy_static! {
static ref BYTES_CHAR: HashMap<u8, char> = bytes_char();
Expand Down
33 changes: 33 additions & 0 deletions tokenizers/src/normalizers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,36 @@ impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);
impl_enum_from!(Replace, NormalizerWrapper, Replace);
impl_enum_from!(Prepend, NormalizerWrapper, Prepend);
impl_enum_from!(ByteLevel, NormalizerWrapper, ByteLevel);

#[cfg(test)]
mod tests {
use super::*;
#[test]
fn post_processor_deserialization_no_type() {
let json = r#"{"strip_left":false, "strip_right":true}"#;
let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
assert!(matches!(
reconstructed.unwrap(),
NormalizerWrapper::StripNormalizer(_)
));

let json =
r#"{"sep":["</s>",2], "cls":["<s>",0], "trim_offsets":true, "add_prefix_space":true}"#;
let reconstructed = serde_json::from_str::<NormalizerWrapper>(json).unwrap();
println!("{:?}", reconstructed);
assert!(matches!(
reconstructed,
NormalizerWrapper::Sequence(_)
));

let json = r#"{"type":"RobertaProcessing", "sep":["</s>",2] }"#;
let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
match reconstructed {
Err(err) => assert_eq!(
err.to_string(),
"data did not match any variant of untagged enum NormalizerWrapper"
),
_ => panic!("Expected an error here"),
}
}
}

0 comments on commit bfb6222

Please sign in to comment.