diff --git a/machine/corpora/parallel_text_corpus.py b/machine/corpora/parallel_text_corpus.py index 5f17c97..f06a4ec 100644 --- a/machine/corpora/parallel_text_corpus.py +++ b/machine/corpora/parallel_text_corpus.py @@ -617,8 +617,8 @@ def count(self, include_empty: bool = True, text_ids: Optional[Iterable[str]] = if include_empty: return len(self._df) return len(self._df[(self._df[self._source_column] != "") & (self._df[self._target_column] != "")]) - return len(self._df[self._df[self._source_column].isin(text_ids)]) & ( - len(self._df[self._target_column].isin(text_ids)) + return len(self._df[self._df[self._source_column].isin(cast(Sequence[str], text_ids))]) & ( + len(self._df[self._target_column].isin(cast(Sequence[str], text_ids)) ) def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[ParallelTextRow, None, None]: