Skip to content

Commit

Permalink
Refactor datacleaner
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasgarbas committed Nov 22, 2024
1 parent 67b5f14 commit 28938a1
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 277 deletions.
31 changes: 17 additions & 14 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import torch
from datasets import Dataset
from datasets import Dataset, load_dataset
from transformer_ranker.datacleaner import DatasetCleaner


Expand All @@ -19,31 +19,34 @@
("text classification", "SetFit/rte", 0.05),
]


@pytest.mark.parametrize("task_type,dataset_name,downsampling_ratio", test_datasets)
def test_datacleaner(task_type, dataset_name, downsampling_ratio):

preprocessor = DatasetCleaner(dataset_downsample=downsampling_ratio)
dataset = preprocessor.prepare_dataset(dataset_name)
dataset = load_dataset(dataset_name, trust_remote_code=True)
datacleaner = DatasetCleaner(dataset_downsample=downsampling_ratio)
dataset = datacleaner.prepare_dataset(dataset)

# Test dataset preprocessing
assert isinstance(dataset, Dataset), f"Dataset '{dataset_name}' is not a valid Dataset object"

assert preprocessor.task_type == task_type, (
f"Task type mismatch: expected '{task_type}', got '{preprocessor.task_type}'"
assert dataset.task_category == task_type, (
f"Task type mismatch: expected '{task_type}', got '{dataset.task_category}'"
f"in dataset '{dataset_name}'"
)

# Make sure text and label columns were found
assert preprocessor.text_column is not None, f"Text column not found in dataset {dataset_name}"
assert preprocessor.label_column is not None, f"Label column not found in dataset {dataset_name}"
assert dataset.text_column is not None, f"Text column not found in dataset {dataset_name}"
assert dataset.label_column is not None, f"Label column not found in dataset {dataset_name}"

# Test texts in the text column
sentences = preprocessor.prepare_sentences(dataset)
# Prepare texts using .texts()
sentences = dataset.texts()
assert isinstance(sentences, list) and len(sentences) > 0, (
"Sentences/tokens list is empty in dataset %s", dataset_name
)

# Ensure the sentences are in the correct format (str for text-classification, List[str] for token-level)
# Ensure the sentences are in the correct format
# (str for text-classification, List[str] for token-level)
if task_type == "text classification":
for sentence in sentences:
assert isinstance(sentence, str), (
Expand All @@ -69,7 +72,7 @@ def test_datacleaner(task_type, dataset_name, downsampling_ratio):
raise KeyError(msg)

# Test the label column in each dataset
labels = preprocessor.prepare_labels(dataset)
labels = dataset.labels()
assert isinstance(labels, torch.Tensor) and labels.size(0) > 0, "Labels tensor is empty"
assert (labels >= 0).all(), f"Negative label found in dataset {dataset_name}"

Expand All @@ -81,7 +84,7 @@ def test_simple_dataset():
"something_else": [0, 1, 2, 3, 4, 5]
})

preprocessor = DatasetCleaner(dataset_downsample=0.5, remove_empty_sentences=False)
preprocessor = DatasetCleaner(dataset_downsample=0.5, cleanup_rows=False)
dataset = preprocessor.prepare_dataset(original_dataset)

assert len(original_dataset) == 6
Expand All @@ -90,13 +93,13 @@ def test_simple_dataset():
assert len(dataset) == 3
assert dataset.column_names == ["text", "label"]

preprocessor = DatasetCleaner(remove_empty_sentences=False)
preprocessor = DatasetCleaner(cleanup_rows=False)
dataset = preprocessor.prepare_dataset(original_dataset)

assert dataset["label"] == [0, 1, 2, 0, 1, 2]
assert original_dataset["label"] == ["X", "Y", "Z", "X", "Y", "Z"]

preprocessor = DatasetCleaner(remove_empty_sentences=True)
preprocessor = DatasetCleaner(cleanup_rows=True)
dataset = preprocessor.prepare_dataset(original_dataset)

# One row should have been removed in the processed dataset
Expand Down
Loading

0 comments on commit 28938a1

Please sign in to comment.