From d8aa049953052f0b25d434f6aca30b9850eac2c2 Mon Sep 17 00:00:00 2001 From: TaperChipmunk32 Date: Tue, 22 Oct 2024 16:13:27 -0500 Subject: [PATCH] Fixed bug in parallel_text_corpus --- machine/corpora/parallel_text_corpus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/machine/corpora/parallel_text_corpus.py b/machine/corpora/parallel_text_corpus.py index 17a5eb0..61b4e94 100644 --- a/machine/corpora/parallel_text_corpus.py +++ b/machine/corpora/parallel_text_corpus.py @@ -588,8 +588,8 @@ def count(self, include_empty: bool = True, text_ids: Optional[Iterable[str]] = if include_empty: return len(self._df) return len(self._df[(self._df[self._source_column] != "") & (self._df[self._target_column] != "")]) - return len(self._df[self._df[self._source_column].isin(text_ids)]) & ( - len(self._df[self._target_column].isin(text_ids)) + return len(self._df[self._df[self._source_column].isin(cast(Sequence[str], text_ids))]) & ( + len(self._df[self._target_column].isin(cast(Sequence[str], text_ids))) ) def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[ParallelTextRow, None, None]: