Added ability to inspect a 'Sequence' decoder and the AddedVocabulary. #1443

Merged 2 commits on Mar 29, 2024
8 changes: 8 additions & 0 deletions tokenizers/src/decoders/sequence.rs
@@ -13,6 +13,14 @@ impl Sequence {
pub fn new(decoders: Vec<DecoderWrapper>) -> Self {
Self { decoders }
}

pub fn get_decoders(&self) -> &[DecoderWrapper] {
&self.decoders
}

pub fn get_decoders_mut(&mut self) -> &mut [DecoderWrapper] {
&mut self.decoders
}
}

impl Decoder for Sequence {
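Usage sketch for the two accessors above. This is a minimal, self-contained illustration rather than code from the PR: `DecoderWrapper` and `Sequence` here are simplified stand-ins for the crate's types (the variants are modeled loosely and are an assumption), and `count_strip_steps` and `set_strip_start` are hypothetical helpers. It shows the likely intent: read-only inspection through `get_decoders`, and in-place edits of individual steps through `get_decoders_mut`.

```rust
// Stand-in types: the real `DecoderWrapper` and `Sequence` live in the
// `tokenizers` crate. These simplified versions only mirror the accessor shape.
#[derive(Debug, Clone, PartialEq)]
pub enum DecoderWrapper {
    ByteFallback,
    Fuse,
    Strip { content: char, start: usize, stop: usize },
}

#[derive(Debug, Clone)]
pub struct Sequence {
    decoders: Vec<DecoderWrapper>,
}

impl Sequence {
    pub fn new(decoders: Vec<DecoderWrapper>) -> Self {
        Self { decoders }
    }

    // Same signatures as the accessors added in the diff.
    pub fn get_decoders(&self) -> &[DecoderWrapper] {
        &self.decoders
    }

    pub fn get_decoders_mut(&mut self) -> &mut [DecoderWrapper] {
        &mut self.decoders
    }
}

/// Hypothetical helper: count the `Strip` steps without mutating anything.
pub fn count_strip_steps(seq: &Sequence) -> usize {
    seq.get_decoders()
        .iter()
        .filter(|d| matches!(d, DecoderWrapper::Strip { .. }))
        .count()
}

/// Hypothetical helper: overwrite `start` on every `Strip` step in place.
pub fn set_strip_start(seq: &mut Sequence, new_start: usize) {
    for decoder in seq.get_decoders_mut() {
        if let DecoderWrapper::Strip { start, .. } = decoder {
            *start = new_start;
        }
    }
}
```

Because `get_decoders_mut` returns `&mut [DecoderWrapper]` rather than `&mut Vec<DecoderWrapper>`, callers can modify or swap existing steps but cannot add or remove steps through it.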
13 changes: 12 additions & 1 deletion tokenizers/src/tokenizer/added_vocabulary.rs
@@ -139,7 +139,7 @@ fn space_rightmost_at_start(sentence: &str) -> usize {
/// exist as required.
///
#[derive(Clone, Debug)]
pub(super) struct AddedVocabulary {
pub struct AddedVocabulary {
/// Contains the mapping from String (token content) to ID. This map contains both special
/// tokens and classic added tokens that were added to this vocabulary.
added_tokens_map: HashMap<String, u32>,
@@ -192,6 +192,11 @@ impl AddedVocabulary {
self.added_tokens_map.len()
}

/// Whether or not this vocabulary is empty
pub fn is_empty(&self) -> bool {
self.added_tokens_map.is_empty()
}

/// Get the additional vocabulary
pub fn get_vocab(&self) -> &HashMap<String, u32> {
&self.added_tokens_map
@@ -487,6 +492,12 @@ impl AddedVocabulary {
}
}

impl Default for AddedVocabulary {
fn default() -> Self {
Self::new()
}
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AddedTokenWithId {
/// The id assigned to this token
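Sketch of what the `is_empty` and `Default` additions enable once the struct is `pub`. This is an illustration under stated assumptions, not the crate's implementation: the `AddedVocabulary` below models only the `added_tokens_map` field, and `insert_for_demo` and `TokenizerConfig` are hypothetical names that do not appear in the diff. The additions are presumably what Clippy's `len_without_is_empty` and `new_without_default` lints ask for on public types, though the PR itself does not say so.

```rust
use std::collections::HashMap;

// Stand-in for the crate's `AddedVocabulary`; only the map field is modeled.
#[derive(Clone, Debug)]
pub struct AddedVocabulary {
    added_tokens_map: HashMap<String, u32>,
}

impl AddedVocabulary {
    pub fn new() -> Self {
        Self {
            added_tokens_map: HashMap::new(),
        }
    }

    /// Hypothetical helper for this sketch only; not part of the diff.
    pub fn insert_for_demo(&mut self, token: &str, id: u32) {
        self.added_tokens_map.insert(token.to_string(), id);
    }

    pub fn len(&self) -> usize {
        self.added_tokens_map.len()
    }

    /// Whether or not this vocabulary is empty (mirrors the added method).
    pub fn is_empty(&self) -> bool {
        self.added_tokens_map.is_empty()
    }

    pub fn get_vocab(&self) -> &HashMap<String, u32> {
        &self.added_tokens_map
    }
}

// Mirrors the added impl: `default()` simply delegates to `new()`.
impl Default for AddedVocabulary {
    fn default() -> Self {
        Self::new()
    }
}

/// Hypothetical downstream struct: it can derive `Default` because its
/// field type now implements `Default`.
#[derive(Default, Debug)]
pub struct TokenizerConfig {
    pub added_vocabulary: AddedVocabulary,
}
```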
17 changes: 17 additions & 0 deletions tokenizers/src/tokenizer/mod.rs
@@ -384,6 +384,12 @@ where
self
}

/// Set the added vocabulary.
pub fn with_added_vocabulary(mut self, added_vocabulary: AddedVocabulary) -> Self {
self.added_vocabulary = added_vocabulary;
self
}

/// Set the truncation parameters.
#[must_use]
pub fn with_truncation(mut self, trunc: Option<TruncationParams>) -> Self {
@@ -598,6 +604,17 @@ where
&self.model
}

/// Set the added vocabulary.
pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self {
self.added_vocabulary = added_vocabulary;
self
}

/// Get the added vocabulary
pub fn get_added_vocabulary(&self) -> &AddedVocabulary {
&self.added_vocabulary
}

/// Set the truncation parameters
///
/// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
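Sketch of how the two `with_added_vocabulary` methods and the getter fit together. All types below are simplified stand-ins, not the crate's `TokenizerBuilder` and `TokenizerImpl`: the generic parameters are dropped, `build` is assumed infallible here (the sketch does not model any error handling), and `from_pairs` is a hypothetical constructor. The point is only the difference in receiver: the builder method takes `self` by value and returns it, while the tokenizer method takes `&mut self` and replaces the vocabulary in place.

```rust
use std::collections::HashMap;

// Stand-in for the crate's `AddedVocabulary`; only the map field is modeled.
#[derive(Clone, Debug, Default)]
pub struct AddedVocabulary {
    added_tokens_map: HashMap<String, u32>,
}

impl AddedVocabulary {
    /// Hypothetical constructor for this sketch; the diff does not add it.
    pub fn from_pairs(pairs: &[(&str, u32)]) -> Self {
        let added_tokens_map = pairs
            .iter()
            .map(|(token, id)| (token.to_string(), *id))
            .collect();
        Self { added_tokens_map }
    }

    pub fn get_vocab(&self) -> &HashMap<String, u32> {
        &self.added_tokens_map
    }
}

// Simplified builder: the real one also carries a model, normalizer, etc.
#[derive(Default)]
pub struct TokenizerBuilder {
    added_vocabulary: AddedVocabulary,
}

impl TokenizerBuilder {
    /// By-value builder method, same shape as the one added in the diff.
    pub fn with_added_vocabulary(mut self, added_vocabulary: AddedVocabulary) -> Self {
        self.added_vocabulary = added_vocabulary;
        self
    }

    /// Assumption: infallible here to keep the sketch short.
    pub fn build(self) -> Tokenizer {
        Tokenizer {
            added_vocabulary: self.added_vocabulary,
        }
    }
}

// Simplified tokenizer holding only the added vocabulary.
pub struct Tokenizer {
    added_vocabulary: AddedVocabulary,
}

impl Tokenizer {
    /// In-place setter, same shape as the one added in the diff.
    pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self {
        self.added_vocabulary = added_vocabulary;
        self
    }

    /// Read-only access, same shape as the getter added in the diff.
    pub fn get_added_vocabulary(&self) -> &AddedVocabulary {
        &self.added_vocabulary
    }
}
```

Note that the setter replaces the whole vocabulary rather than merging into it, so tokens from the previous vocabulary are gone after the call.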