diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs
index 484df6c95..73169b695 100644
--- a/tokenizers/src/decoders/sequence.rs
+++ b/tokenizers/src/decoders/sequence.rs
@@ -13,6 +13,14 @@ impl Sequence {
     pub fn new(decoders: Vec<DecoderWrapper>) -> Self {
         Self { decoders }
     }
+
+    pub fn get_decoders(&self) -> &[DecoderWrapper] {
+        &self.decoders
+    }
+
+    pub fn get_decoders_mut(&mut self) -> &mut [DecoderWrapper] {
+        &mut self.decoders
+    }
 }
 
 impl Decoder for Sequence {
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index ae6a64362..01a598423 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -384,6 +384,12 @@ where
         self
     }
 
+    /// Set the added vocabulary.
+    pub fn with_added_vocabulary(mut self, added_vocabulary: AddedVocabulary) -> Self {
+        self.added_vocabulary = added_vocabulary;
+        self
+    }
+
     /// Set the trunaction parameters.
     #[must_use]
     pub fn with_truncation(mut self, trunc: Option<TruncationParams>) -> Self {
@@ -598,6 +604,17 @@ where
         &self.model
     }
 
+    /// Set the added vocabulary.
+    pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self {
+        self.added_vocabulary = added_vocabulary.into();
+        self
+    }
+
+    /// Get the added vocabulary
+    pub fn get_added_vocabulary(&self) -> &AddedVocabulary {
+        &self.added_vocabulary
+    }
+
     /// Set the truncation parameters
     ///
     /// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`