Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Aug 5, 2024
1 parent 024ccc2 commit dc41e37
Showing 1 changed file with 70 additions and 3 deletions.
73 changes: 70 additions & 3 deletions tokenizers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ pub mod unigram;
pub mod wordlevel;
pub mod wordpiece;

use serde::{
self,
de::{value::MapAccessDeserializer, Error, MapAccess, Visitor},
ser::SerializeStruct,

Check warning on line 11 in tokenizers/src/models/mod.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.8)

unused import: `ser::SerializeStruct`

Check warning on line 11 in tokenizers/src/models/mod.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.9)

unused import: `ser::SerializeStruct`

Check warning on line 11 in tokenizers/src/models/mod.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.7)

unused import: `ser::SerializeStruct`

Check warning on line 11 in tokenizers/src/models/mod.rs

View workflow job for this annotation

GitHub Actions / Check it builds for Windows 32-bit (3.10)

unused import: `ser::SerializeStruct`

Check warning on line 11 in tokenizers/src/models/mod.rs

View workflow job for this annotation

GitHub Actions / Check everything builds

unused import: `ser::SerializeStruct`

Check warning on line 11 in tokenizers/src/models/mod.rs

View workflow job for this annotation

GitHub Actions / Check everything builds & tests (ubuntu-latest)

unused import: `ser::SerializeStruct`
Deserialize, Deserializer, Serialize, Serializer,
};
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use serde::{Deserialize, Serialize, Serializer};

use crate::models::bpe::{BpeTrainer, BPE};
use crate::models::unigram::{Unigram, UnigramTrainer};
use crate::models::wordlevel::{WordLevel, WordLevelTrainer};
Expand Down Expand Up @@ -57,7 +61,8 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
}
}

#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
#[derive(Serialize, Debug, PartialEq, Clone)]
#[serde(untagged)]
pub enum ModelWrapper {
BPE(BPE),
// WordPiece must stay before WordLevel here for deserialization (for retrocompatibility
Expand All @@ -67,6 +72,68 @@ pub enum ModelWrapper {
Unigram(Unigram),
}

impl<'de> Deserialize<'de> for ModelWrapper {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// Define a Visitor struct that will handle both tagged and untagged deserialization
struct ModelWrapperVisitor;

impl<'de> Visitor<'de> for ModelWrapperVisitor {
type Value = ModelWrapper;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a valid ModelWrapper")
}

fn visit_map<V>(self, mut map: V) -> std::result::Result<ModelWrapper, V::Error>
where
V: MapAccess<'de>,
{
println!("Parsed stuff????");

// Look for a "type" tag
if let Some(key) = map.next_key::<String>()? {
println!("Matched key: {:?}", key);
match key.as_str() {
"type" => {
let tag: String = map.next_value()?;
println!("Matched tag: {:?}", tag);

match tag.as_str() {
"BPE" => {
BPE::deserialize(MapAccessDeserializer::new(map))?;
}
"WordPiece" => {
WordPiece::deserialize(MapAccessDeserializer::new(map))?;
}
"WordLevel" => {
WordLevel::deserialize(MapAccessDeserializer::new(map))?;
}
"Unigram" => {
Unigram::deserialize(MapAccessDeserializer::new(map))?;
}
_ => {
return Err(V::Error::unknown_variant(
&tag,
&["BPE", "WordPiece", "WordLevel", "Unigram"],
));
}
}
}
_ => todo!(),
}
}

Err(V::Error::custom("invalid input"))
}
}

deserializer.deserialize_any(ModelWrapperVisitor)
}
}

impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
impl_enum_from!(BPE, ModelWrapper, BPE);
Expand Down

0 comments on commit dc41e37

Please sign in to comment.