Skip to content

Commit

Permalink
fix encode batch fast as well
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Nov 5, 2024
1 parent 495d430 commit 8c96f47
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions bindings/python/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1092,25 +1092,24 @@ impl PyTokenizer {
 fn encode_batch_fast(
 &self,
 py: Python<'_>,
-input: Bound<'_, PyList>,
+input: Bound<'_, PySequence>,
 is_pretokenized: bool,
 add_special_tokens: bool,
 ) -> PyResult<Vec<PyEncoding>> {
-let input: Vec<tk::EncodeInput> = input
-.into_iter()
-.map(|o| {
-let input: tk::EncodeInput = if is_pretokenized {
-o.extract::<PreTokenizedEncodeInput>()?.into()
-} else {
-o.extract::<TextEncodeInput>()?.into()
-};
-Ok(input)
-})
-.collect::<PyResult<Vec<tk::EncodeInput>>>()?;
+let mut items = Vec::<tk::EncodeInput>::new();
+for i in 0..input.len()? {
+let item = input.get_item(i)?;
+let item: tk::EncodeInput = if is_pretokenized {
+item.extract::<PreTokenizedEncodeInput>()?.into()
+} else {
+item.extract::<TextEncodeInput>()?.into()
+};
+items.push(item);
+}
 py.allow_threads(|| {
 ToPyResult(
 self.tokenizer
-.encode_batch_fast(input, add_special_tokens)
+.encode_batch_fast(items, add_special_tokens)
 .map(|encodings| encodings.into_iter().map(|e| e.into()).collect()),
 )
 .into()
Expand Down

0 comments on commit 8c96f47

Please sign in to comment.