Added tests with two small models
lukasgarbas committed Oct 26, 2024
1 parent a25852b commit ea65cec
Showing 11 changed files with 188 additions and 242 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,6 +4,6 @@ pythonpath = [
 ]
 
 [tool.mypy]
-files="transformer_ranker,tests"
+files="transformer_ranker"
 ignore_missing_imports = true
 check_untyped_defs = true
3 changes: 2 additions & 1 deletion requirements.txt
@@ -4,4 +4,5 @@ tokenizers
 torch
 torchmetrics
 tqdm
-transformers
+transformers
+scikit-learn
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@ def read_requirements():
     name='transformer-ranker',
     version='0.1.0',
     packages=find_packages(),
-    description='Rank transformer models for NLP tasks using transferability measures',
+    description='Efficiently find the best-suited language model (LM) for your NLP task',
     long_description=open('README.md').read(),
     long_description_content_type="text/markdown",
     author='Lukas Garbas',
35 changes: 35 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,35 @@
+import pytest
+import torch
+from datasets import load_dataset
+from sklearn import datasets
+from transformers import AutoModel
+
+
+@pytest.fixture(scope="session")
+def small_language_models():
+    """Use two tiny models for testing"""
+    return (
+        AutoModel.from_pretrained("prajjwal1/bert-tiny"),
+        AutoModel.from_pretrained("google/electra-small-discriminator")
+    )
+
+
+@pytest.fixture(scope="session")
+def conll():
+    return load_dataset("conll2003")
+
+
+@pytest.fixture(scope="session")
+def trec():
+    return load_dataset("trec")
+
+
+@pytest.fixture(scope="session")
+def iris_dataset():
+    iris = datasets.load_iris()
+    data = torch.tensor(iris["data"], dtype=torch.float32)
+    data[142] += torch.tensor([0, 0, 0, 0.01])  # Ensure no exact duplicates
+    return {
+        "data": data,
+        "target": torch.tensor(iris["target"], dtype=torch.float32)
+    }
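
As context for how these fixtures get consumed: pytest injects a fixture by matching the test argument name, and scope="session" means each model and dataset loads only once per test run. A minimal sketch, assuming both checkpoints load as plain AutoModel instances (the test function test_small_models_forward is hypothetical, not part of this commit):

import torch


def test_small_models_forward(small_language_models):
    # The argument name matches the fixture above; pytest reuses the two
    # loaded models across the whole session (scope="session").
    for model in small_language_models:
        input_ids = torch.tensor([[101, 7592, 2088, 102]])  # arbitrary WordPiece ids
        output = model(input_ids=input_ids)
        assert output.last_hidden_state.shape[:2] == (1, 4)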
100 changes: 0 additions & 100 deletions tests/datasets/test_datacleaner.py

This file was deleted.

48 changes: 0 additions & 48 deletions tests/datasets/test_labels.py

This file was deleted.

72 changes: 0 additions & 72 deletions tests/datasets/test_sentences.py

This file was deleted.

19 changes: 0 additions & 19 deletions tests/estimators/conftest.py

This file was deleted.

86 changes: 86 additions & 0 deletions tests/test_datasets.py
@@ -0,0 +1,86 @@
+from typing import List, Tuple, Type, Union
+
+import pytest
+import torch
+from datasets import Dataset
+from transformer_ranker.datacleaner import DatasetCleaner
+
+
+def load_datasets(dataset_type: str, num_datasets: Union[str, int] = "all") -> Tuple[List[str], str, Type]:
+    """Try loading and preparing different datasets"""
+    dataset_map = {
+        'token': (
+            ["conll2003", "wnut_17"],
+            "token classification", list
+        ),
+        'text': (
+            ["trec", "stanfordnlp/sst2", "hate_speech18"],
+            "text classification", str
+        ),
+        'text_pair': (
+            ["yangwang825/sick", "SetFit/rte"],
+            "text classification", str
+        )
+    }
+
+    datasets, task_type, sentence_type = dataset_map[dataset_type]
+    if isinstance(num_datasets, int):
+        datasets = datasets[:num_datasets]
+
+    return datasets, task_type, sentence_type
+
+
+def validate_dataset(
+    preprocessor,
+    dataset_name: str,
+    dataset: Dataset,
+    expected_task_type: str,
+    sentence_type: Type
+):
+    assert isinstance(dataset, Dataset), f"Dataset '{dataset_name}' is not a valid Dataset object"
+
+    assert preprocessor.task_type == expected_task_type, \
+        (f"Task type mismatch: expected '{expected_task_type}', got '{preprocessor.task_type}' "
+         f"in dataset '{dataset_name}'")
+
+    # Make sure text and label columns were found
+    assert preprocessor.text_column is not None, f"Text column not found in dataset {dataset_name}"
+    assert preprocessor.label_column is not None, f"Label column not found in dataset {dataset_name}"
+
+    # Test texts in the text column
+    sentences = preprocessor.prepare_sentences(dataset)
+    assert isinstance(sentences, list) and len(sentences) > 0, \
+        f"Sentences/tokens list is empty in dataset {dataset_name}"
+    assert all(isinstance(sentence, sentence_type) for sentence in sentences), \
+        (f"Incorrect sentence/token type in dataset '{dataset_name}': all expected to be '{sentence_type}' "
+         f"but some sentences have a different type")
+
+    if sentence_type == str:
+        # For text and text-pair classification, make sure there are no empty strings
+        assert all(sentence != "" for sentence in sentences), f"Empty sentence found in dataset {dataset_name}"
+
+    if sentence_type == list:
+        # For token classification, make sure there are no empty token lists
+        assert all(sentence != [] for sentence in sentences), f"Empty token list found in dataset {dataset_name}"
+        # Check that no empty strings exist within the token lists
+        assert all(all(token != "" for token in sentence) for sentence in sentences), \
+            f"Empty token found within a sentence in dataset {dataset_name}"
+
+    # Test the label column in each dataset
+    labels = preprocessor.prepare_labels(dataset)
+    assert isinstance(labels, torch.Tensor) and labels.size(0) > 0, "Labels tensor is empty"
+    assert (labels >= 0).all(), f"Negative label found in dataset {dataset_name}"
+
+
+@pytest.mark.parametrize("dataset_type", ["text", "token", "text_pair"])
+def test_datacleaner(dataset_type):
+    datasets, task_type, sentence_type = load_datasets(dataset_type, "all")
+
+    # Loop through all test datasets, downsampling each to 20%
+    for dataset_name in datasets:
+        preprocessor = DatasetCleaner(dataset_downsample=0.2)
+        dataset = preprocessor.prepare_dataset(dataset_name)
+
+        # Test dataset preprocessing
+        validate_dataset(preprocessor, dataset_name, dataset, task_type, sentence_type)
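
The same code path can also be run standalone outside pytest. A sketch using only the DatasetCleaner calls exercised above (it assumes the dataset is reachable on the Hugging Face Hub):

from transformer_ranker.datacleaner import DatasetCleaner

preprocessor = DatasetCleaner(dataset_downsample=0.2)
dataset = preprocessor.prepare_dataset("trec")  # any dataset name from dataset_map above

sentences = preprocessor.prepare_sentences(dataset)  # List[str] for text classification
labels = preprocessor.prepare_labels(dataset)        # torch.Tensor of non-negative label ids
print(preprocessor.task_type, len(sentences), labels.shape)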