Fix encode_batch and encode_batch_fast to accept ndarrays again
Dimitris Iliopoulos committed Nov 8, 2024
1 parent 5aa9f6c commit 99a4dc1
Showing 2 changed files with 14 additions and 16 deletions.
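
The regression being fixed: an earlier change moved these entry points to a `Bound<'_, PySequence>` argument, which rejects numpy ndarrays even though they behave like sequences, so `encode_batch` and `encode_batch_fast` stopped accepting them. A minimal sketch of the call this commit makes work again, assuming a serialized tokenizer at the hypothetical path `tokenizer.json`:

import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical path

# After this fix, an ndarray of strings is a valid batch again
batch = np.array(["Hello there", "General Kenobi"])
encodings = tokenizer.encode_batch(batch)
print([e.tokens for e in encodings])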
28 changes: 14 additions & 14 deletions bindings/python/src/tokenizer.rs
@@ -407,10 +407,10 @@ impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
         if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
             return Ok(Self((i1, i2).into()));
         }
-        if let Ok(arr) = ob.downcast::<PyList>() {
+        if let Ok(arr) = ob.extract::<Vec<Bound<PyAny>>>() {
             if arr.len() == 2 {
-                let first = arr.get_item(0)?.extract::<TextInputSequence>()?;
-                let second = arr.get_item(1)?.extract::<TextInputSequence>()?;
+                let first = arr[0].extract::<TextInputSequence>()?;
+                let second = arr[1].extract::<TextInputSequence>()?;
                 return Ok(Self((first, second).into()));
             }
         }
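
In the pair branch above, the old code matched only genuine Python lists via `downcast::<PyList>`; extracting `Vec<Bound<PyAny>>` instead accepts any object that passes the looser sequence check, so a two-element ndarray can also supply a pair input. A hedged sketch, reusing the assumed `tokenizer.json` (whether each ndarray element converts depends on its dtype; numpy's string scalars subclass `str`, so they should):

import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical path

pair_as_list = ["first sequence", "second sequence"]               # worked before and after
pair_as_ndarray = np.array(["first sequence", "second sequence"])  # matches only after this change
encodings = tokenizer.encode_batch([pair_as_list, pair_as_ndarray])
print(encodings[0].tokens == encodings[1].tokens)  # expected True: both parse as the same pair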
@@ -434,10 +434,10 @@ impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
         {
             return Ok(Self((i1, i2).into()));
         }
-        if let Ok(arr) = ob.downcast::<PyList>() {
+        if let Ok(arr) = ob.extract::<Vec<Bound<PyAny>>>() {
             if arr.len() == 2 {
-                let first = arr.get_item(0)?.extract::<PreTokenizedInputSequence>()?;
-                let second = arr.get_item(1)?.extract::<PreTokenizedInputSequence>()?;
+                let first = arr[0].extract::<PreTokenizedInputSequence>()?;
+                let second = arr[1].extract::<PreTokenizedInputSequence>()?;
                 return Ok(Self((first, second).into()));
             }
         }
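
The pretokenized branch gets the identical substitution, so the two-element pair container no longer has to be a `PyList` there either. A short sketch under the same `tokenizer.json` assumption, keeping plain word lists for the inner sequences:

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical path

# Each batch item is a (words, pair_words) two-element sequence
pair = [["hello", "there"], ["general", "kenobi"]]
encodings = tokenizer.encode_batch([pair], is_pretokenized=True)
print(encodings[0].tokens)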
@@ -1032,13 +1032,13 @@ impl PyTokenizer {
     fn encode_batch(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PySequence>,
+        input: Vec<Bound<'_, PyAny>>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
-        for i in 0..input.len()? {
-            let item = input.get_item(i)?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len());
+        for i in 0..input.len() {
+            let item = &input[i];
             let item: tk::EncodeInput = if is_pretokenized {
                 item.extract::<PreTokenizedEncodeInput>()?.into()
             } else {
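
Why the signature swap restores ndarray support: PyO3's `PySequence` downcast is stricter than the C-level sequence protocol (to our understanding it wants types registered with `collections.abc.Sequence`, which `numpy.ndarray` is not), while extraction into a `Vec` succeeds for anything that indexes and measures like a sequence. The asymmetry is visible from pure Python; the PyO3 detail is an assumption, the isinstance result is not:

import collections.abc
import numpy as np

batch = np.array(["Hello there", "General Kenobi"])

# ndarrays support len() and integer indexing like any sequence...
print(len(batch), batch[0])  # 2 Hello there

# ...yet are not registered as collections.abc.Sequence, which is
# (we believe) why the old Bound<'_, PySequence> argument rejected them
print(isinstance(batch, collections.abc.Sequence))  # False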
@@ -1092,13 +1092,13 @@ impl PyTokenizer {
     fn encode_batch_fast(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PySequence>,
+        input: Vec<Bound<'_, PyAny>>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
-        for i in 0..input.len()? {
-            let item = input.get_item(i)?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len());
+        for i in 0..input.len() {
+            let item = &input[i];
             let item: tk::EncodeInput = if is_pretokenized {
                 item.extract::<PreTokenizedEncodeInput>()?.into()
             } else {
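
`encode_batch_fast` receives the same signature change, so both entry points accept the same batch containers. A brief sketch under the same `tokenizer.json` assumption:

import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical path

# The same ndarray batch through the fast variant
encodings = tokenizer.encode_batch_fast(np.array(["one sequence", "another one"]))
print(len(encodings))  # 2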
2 changes: 0 additions & 2 deletions bindings/python/tests/bindings/test_tokenizer.py
@@ -152,8 +152,6 @@ def test_encode(self):
         assert len(output) == 2
 
     def test_encode_formats(self, bert_files):
-        print("Broken by the change from std::usize::Max to usixeMax")
-        return 0
         with pytest.deprecated_call():
             tokenizer = BertWordPieceTokenizer(bert_files["vocab"])
