diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs
index e0fc1f022..7b45d94a3 100644
--- a/tokenizers/src/models/bpe/serialization.rs
+++ b/tokenizers/src/models/bpe/serialization.rs
@@ -28,14 +28,14 @@ impl Serialize for BPE {
             .map(|(pair, (rank, _))| (pair, rank))
             .collect();
         merges.sort_unstable_by_key(|k| *k.1);
-        let merges_str = merges
+        let merges = merges
             .into_iter()
-            .map(|(pair, _)| format!("{} {}", self.vocab_r[&pair.0], self.vocab_r[&pair.1]))
+            .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
             .collect::<Vec<_>>();
         let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
 
         model.serialize_field("vocab", &ordered_vocab)?;
-        model.serialize_field("merges", &merges_str)?;
+        model.serialize_field("merges", &merges)?;
 
         model.end()
     }
@@ -77,7 +77,14 @@ impl<'de> Visitor<'de> for BPEVisitor {
     {
         let mut builder = BpeBuilder::new();
         let mut vocab: Option<HashMap<String, u32>> = None;
-        let mut merges: Option<Vec<String>> = None;
+
+        #[derive(Debug, Deserialize)]
+        #[serde(untagged)]
+        enum MergeType {
+            Tuple(Vec<(String, String)>),
+            Legacy(Vec<String>),
+        }
+        let mut merges: Option<MergeType> = None;
         while let Some(key) = map.next_key::<String>()? {
             match key.as_ref() {
                 "dropout" => {
@@ -120,8 +127,12 @@ impl<'de> Visitor<'de> for BPEVisitor {
             }
         }
         if let (Some(vocab), Some(merges)) = (vocab, merges) {
-            let merges =
-                convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?;
+            let merges = match merges {
+                MergeType::Tuple(merges) => merges,
+                MergeType::Legacy(merges) => {
+                    convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?
+                }
+            };
             builder = builder.vocab_and_merges(vocab, merges);
             Ok(builder.build().map_err(Error::custom)?)
         } else {
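
Note: the `#[serde(untagged)]` enum introduced in the patch lets deserialization accept both the new tuple form of `merges` and the legacy space-separated strings. A minimal, standalone sketch of that behavior (not part of this patch; assumes `serde` with the derive feature and `serde_json` as dependencies):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum MergeType {
    Tuple(Vec<(String, String)>), // new format: [["a", "b"], ["ab", "c"]]
    Legacy(Vec<String>),          // legacy format: ["a b", "ab c"]
}

fn main() {
    // Untagged enums try each variant in order: arrays of two-element arrays
    // match `Tuple`, while arrays of plain strings fall through to `Legacy`.
    let new_style: MergeType = serde_json::from_str(r#"[["a", "b"], ["ab", "c"]]"#).unwrap();
    let legacy: MergeType = serde_json::from_str(r#"["a b", "ab c"]"#).unwrap();
    println!("{:?}\n{:?}", new_style, legacy);
}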