From 4bc7b4cc2a40586f20f7c4a9c656fc2f6fc20889 Mon Sep 17 00:00:00 2001
From: Dimitris Iliopoulos
Date: Thu, 21 Nov 2024 05:55:11 -0500
Subject: [PATCH] Fix encode_batch and encode_batch_fast to accept ndarrays again (#1679)

* Fix encode_batch and encode_batch_fast to accept ndarrays again

* Fix clippy

---------

Co-authored-by: Dimitris Iliopoulos
---
 bindings/python/src/tokenizer.rs              | 26 +++++++++----------
 .../python/tests/bindings/test_tokenizer.py   |  2 --
 2 files changed, 12 insertions(+), 16 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 09fb891e1..52b86d975 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -408,10 +408,10 @@ impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
         if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
             return Ok(Self((i1, i2).into()));
         }
-        if let Ok(arr) = ob.downcast::<PyList>() {
+        if let Ok(arr) = ob.extract::<Vec<Bound<PyAny>>>() {
             if arr.len() == 2 {
-                let first = arr.get_item(0)?.extract::<TextInputSequence>()?;
-                let second = arr.get_item(1)?.extract::<TextInputSequence>()?;
+                let first = arr[0].extract::<TextInputSequence>()?;
+                let second = arr[1].extract::<TextInputSequence>()?;
                 return Ok(Self((first, second).into()));
             }
         }
@@ -435,10 +435,10 @@ impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
         {
             return Ok(Self((i1, i2).into()));
         }
-        if let Ok(arr) = ob.downcast::<PyList>() {
+        if let Ok(arr) = ob.extract::<Vec<Bound<PyAny>>>() {
             if arr.len() == 2 {
-                let first = arr.get_item(0)?.extract::<PreTokenizedInputSequence>()?;
-                let second = arr.get_item(1)?.extract::<PreTokenizedInputSequence>()?;
+                let first = arr[0].extract::<PreTokenizedInputSequence>()?;
+                let second = arr[1].extract::<PreTokenizedInputSequence>()?;
                 return Ok(Self((first, second).into()));
             }
         }
@@ -1033,13 +1033,12 @@
     fn encode_batch(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PySequence>,
+        input: Vec<Bound<'_, PyAny>>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
-        for i in 0..input.len()? {
-            let item = input.get_item(i)?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len());
+        for item in &input {
             let item: tk::EncodeInput = if is_pretokenized {
                 item.extract::<PreTokenizedEncodeInput>()?.into()
             } else {
@@ -1093,13 +1092,12 @@
     fn encode_batch_fast(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PySequence>,
+        input: Vec<Bound<'_, PyAny>>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
-        for i in 0..input.len()? {
-            let item = input.get_item(i)?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len());
+        for item in &input {
             let item: tk::EncodeInput = if is_pretokenized {
                 item.extract::<PreTokenizedEncodeInput>()?.into()
             } else {
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index ffa86f1be..d50f283e7 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -153,8 +153,6 @@ def test_encode(self):
         assert len(output) == 2
 
     def test_encode_formats(self, bert_files):
-        print("Broken by the change from std::usize::Max to usixeMax")
-        return 0
         with pytest.deprecated_call():
             tokenizer = BertWordPieceTokenizer(bert_files["vocab"])
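
Illustrative usage note (not part of the patch): the sketch below shows the call pattern this change is meant to restore, namely passing a NumPy array of strings to encode_batch and encode_batch_fast. It assumes a tokenizers build containing this fix, NumPy installed, and a hypothetical local "tokenizer.json" file; the sample sentences are arbitrary.

# Minimal sketch, assuming a tokenizers build with this patch and a
# hypothetical "tokenizer.json" on disk.
import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")

# Plain Python lists keep working as before.
encodings = tokenizer.encode_batch(["hello world", "how are you"])
print(len(encodings))

# With this patch, NumPy string arrays are accepted again: the batch argument
# is now extracted item by item rather than required to be a PySequence.
arr = np.array(["hello world", "how are you"])
print(len(tokenizer.encode_batch(arr)))
print(len(tokenizer.encode_batch_fast(arr)))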