diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 8a380f6..c467464 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,6 +1,6 @@
 import pytest
 import torch
-from datasets import Dataset
+from datasets import Dataset, load_dataset
 
 from transformer_ranker.datacleaner import DatasetCleaner
 
@@ -19,31 +19,34 @@
     ("text classification", "SetFit/rte", 0.05),
 ]
 
+
 @pytest.mark.parametrize("task_type,dataset_name,downsampling_ratio", test_datasets)
 def test_datacleaner(task_type, dataset_name, downsampling_ratio):
-    preprocessor = DatasetCleaner(dataset_downsample=downsampling_ratio)
-    dataset = preprocessor.prepare_dataset(dataset_name)
+    dataset = load_dataset(dataset_name, trust_remote_code=True)
+    datacleaner = DatasetCleaner(dataset_downsample=downsampling_ratio)
+    dataset = datacleaner.prepare_dataset(dataset)
 
     # Test dataset preprocessing
     assert isinstance(dataset, Dataset), f"Dataset '{dataset_name}' is not a valid Dataset object"
-    assert preprocessor.task_type == task_type, (
-        f"Task type mismatch: expected '{task_type}', got '{preprocessor.task_type}'"
+    assert dataset.task_category == task_type, (
+        f"Task type mismatch: expected '{task_type}', got '{dataset.task_category}' "
         f"in dataset '{dataset_name}'"
     )
 
     # Make sure text and label columns were found
-    assert preprocessor.text_column is not None, f"Text column not found in dataset {dataset_name}"
-    assert preprocessor.label_column is not None, f"Label column not found in dataset {dataset_name}"
+    assert dataset.text_column is not None, f"Text column not found in dataset {dataset_name}"
+    assert dataset.label_column is not None, f"Label column not found in dataset {dataset_name}"
 
-    # Test texts in the text column
-    sentences = preprocessor.prepare_sentences(dataset)
+    # Prepare texts using .texts()
+    sentences = dataset.texts()
     assert isinstance(sentences, list) and len(sentences) > 0, (
         "Sentences/tokens list is empty in dataset %s", dataset_name
     )
 
-    # Ensure the sentences are in the correct format (str for text-classification, List[str] for token-level)
+    # Ensure the sentences are in the correct format
+    # (str for text-classification, List[str] for token-level)
     if task_type == "text classification":
         for sentence in sentences:
             assert isinstance(sentence, str), (
@@ -69,7 +72,7 @@ def test_datacleaner(task_type, dataset_name, downsampling_ratio):
         raise KeyError(msg)
 
     # Test the label column in each dataset
-    labels = preprocessor.prepare_labels(dataset)
+    labels = dataset.labels()
     assert isinstance(labels, torch.Tensor) and labels.size(0) > 0, "Labels tensor is empty"
     assert (labels >= 0).all(), f"Negative label found in dataset {dataset_name}"
 
@@ -81,7 +84,7 @@ def test_simple_dataset():
         "something_else": [0, 1, 2, 3, 4, 5]
     })
 
-    preprocessor = DatasetCleaner(dataset_downsample=0.5, remove_empty_sentences=False)
+    preprocessor = DatasetCleaner(dataset_downsample=0.5, cleanup_rows=False)
     dataset = preprocessor.prepare_dataset(original_dataset)
 
     assert len(original_dataset) == 6
@@ -90,13 +93,13 @@
     assert len(dataset) == 3
     assert dataset.column_names == ["text", "label"]
 
-    preprocessor = DatasetCleaner(remove_empty_sentences=False)
+    preprocessor = DatasetCleaner(cleanup_rows=False)
     dataset = preprocessor.prepare_dataset(original_dataset)
 
     assert dataset["label"] == [0, 1, 2, 0, 1, 2]
     assert original_dataset["label"] == ["X", "Y", "Z", "X", "Y", "Z"]
 
-    preprocessor = DatasetCleaner(remove_empty_sentences=True)
+    preprocessor = DatasetCleaner(cleanup_rows=True)
     dataset = preprocessor.prepare_dataset(original_dataset)
 
     # One row should have been removed in the processed dataset
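
A minimal usage sketch of the reworked datacleaner API exercised by the test above, assuming the "SetFit/rte" dataset from the parametrized list is available:

    from datasets import load_dataset
    from transformer_ranker.datacleaner import DatasetCleaner

    # Load a raw dataset, then clean and down-sample it to 5%
    dataset = load_dataset("SetFit/rte", trust_remote_code=True)
    datacleaner = DatasetCleaner(dataset_downsample=0.05)
    dataset = datacleaner.prepare_dataset(dataset)

    # The returned PreprocessedDataset exposes the fields the tests assert on
    print(dataset.task_category)  # "text classification"
    sentences = dataset.texts()   # list of str (or lists of tokens for token-level tasks)
    labels = dataset.labels()     # torch.Tensor of label ids
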
diff --git a/transformer_ranker/datacleaner.py b/transformer_ranker/datacleaner.py
index 8f71692..fc45b80 100644
--- a/transformer_ranker/datacleaner.py
+++ b/transformer_ranker/datacleaner.py
@@ -1,5 +1,7 @@
 import logging
-from typing import Optional, Type, Union
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional, Union
 
 import datasets
 import torch
@@ -11,193 +13,158 @@
 logger = configure_logger("transformer_ranker", logging.INFO)
 
 
-class DatasetCleaner:
+class TaskCategory(str, Enum):
+    """Supported tasks"""
+    TEXT_REGRESSION = "text regression"
+    TEXT_CLASSIFICATION = "text classification"
+    TOKEN_CLASSIFICATION = "token classification"
+
+    def __str__(self):
+        return self.value
+
+
+class PreprocessedDataset(Dataset):
+    """A preprocessed dataset with only the required columns (texts and labels),
+    down-sampled and cleaned. Provides easy access to texts (sentences/words),
+    labels (tensors), and the task category (classification or regression)."""
     def __init__(
         self,
-        pre_tokenizer: Optional[Whitespace] = None,
-        merge_data_splits: bool = True,
-        remove_empty_sentences: bool = True,
-        change_bio_encoding: bool = True,
-        dataset_downsample: Optional[float] = None,
-        task_type: Optional[str] = None,
-        text_column: Optional[str] = None,
-        label_column: Optional[str] = None,
-        label_map: Optional[dict[str, int]] = None,
-        text_pair_column: Optional[str] = None,
+        dataset: Dataset,
+        text_column: str,
+        label_column: str,
+        task_category: TaskCategory,
     ):
-        """
-        Prepare huggingface dataset. Identify task category, find text and label columns,
-        merge data splits, down-sample, prepare texts and labels.
-
-        :param pre_tokenizer: Pre-tokenizer to use, such as Whitespace from huggingface.
-        :param merge_data_splits: Whether to merge train, dev, and test splits into one.
-        :param change_bio_encoding: Convert BIO to single-class labels, removing B-, I-, O- prefix.
-        :param remove_empty_sentences: Whether to remove empty sentences.
-        :param dataset_downsample: Fraction to reduce the dataset size.
-        :param task_type: "token classification", "text classification", or "text regression".
-        :param text_column: Column name for texts.
-        :param label_column: Column name for labels.
-        :param label_map: A dictionary which maps label names to integers.
-        :param text_pair_column: Column name for second text (for entailment tasks).
-        """
-        self.pre_tokenizer = pre_tokenizer
-        self.merge_data_splits = merge_data_splits
-        self.change_bio_encoding = change_bio_encoding
-        self.remove_empty_sentences = remove_empty_sentences
-        self.dataset_downsample = dataset_downsample
-        self.task_type = task_type
+        super().__init__(dataset.data, dataset.info)
+
         self.text_column = text_column
         self.label_column = label_column
-        self.label_map = label_map
-        self.text_pair_column = text_pair_column
-        self.dataset_size = 0
-
-    def prepare_dataset(
-        self, dataset: Union[str, DatasetDict, Dataset]
-    ) -> Union[Dataset, DatasetDict]:
-        """Preprocess a dataset, leave only needed columns, down-sample
-
-        :param dataset: dataset from huggingface. It can be one of the following:
-        a DatasetDict (containing multiple splits) or a single dataset split (e.g., Dataset)
-        :return: Return clean and preprocessed dataset, that can be used in the transformer-ranker
-        """
-        # Load huggingface dataset
-        if isinstance(dataset, str):
-            dataset = datasets.load_dataset(dataset, trust_remote_code=True)
-
-        if not isinstance(dataset, (DatasetDict, Dataset)):
-            raise ValueError(
-                "The dataset must be an instance of either DatasetDict (for multiple splits) "
-                "or Dataset (for a single split) to be preprocessed."
-            )
+        self.task_category = task_category
 
-        if self.merge_data_splits and isinstance(dataset, DatasetDict):
-            dataset = self._merge_data_splits(dataset)
+    def texts(self) -> list[str]:
+        """Gather all texts from the text column."""
+        return self[self.text_column]
 
-        # Find text and label columns
-        text_column, label_column, label_type = self._find_text_and_label_columns(
-            dataset, self.text_column, self.label_column
-        )
-
-        # Find task category based on label type
-        if not self.task_type:
-            task_type = self._find_task_type(label_column, label_type)
+    def labels(self) -> torch.Tensor:
+        """Prepare labels as tensors."""
+        if self.task_category == TaskCategory.TOKEN_CLASSIFICATION:
+            labels = [word_label for labels in self[self.label_column] for word_label in labels]
         else:
-            task_type = self.task_type
-
-        if self.remove_empty_sentences:
-            dataset = self._remove_empty_rows(
-                dataset,
-                text_column,
-                label_column,
-                is_regression=task_type == "text regression"
-            )
-
-        if self.dataset_downsample:
-            dataset = self._downsample(dataset, ratio=self.dataset_downsample)
+            labels = self[self.label_column]
+        return torch.tensor(labels)
 
-        # Pre-tokenize sentences if pre-tokenizer is given
-        if not task_type == "token classification" and self.pre_tokenizer:
-            dataset = self._tokenize(dataset, self.pre_tokenizer, text_column)
 
-        # Concatenate text columns for text-pair tasks
+@dataclass
+class DatasetCleaner:
+    dataset_downsample: Optional[float] = None
+    text_column: Optional[str] = None
+    text_pair_column: Optional[str] = None
+    label_column: Optional[str] = None
+    label_map: Optional[dict] = None
+    task_type: Optional[TaskCategory] = None
+    cleanup_rows: bool = True
+    convert_bio_encoding: bool = True
+    tokenize: bool = False
+
+    def prepare_dataset(self, dataset: Union[str, Dataset, DatasetDict]) -> PreprocessedDataset:
+        """Clean and verify the dataset by finding text and label fields, task type,
+        removing invalid entries, mapping labels, and down-sampling."""
+
+        # Verify a dataset
+        if not isinstance(dataset, (Dataset, DatasetDict)):
+            raise ValueError(f"Unsupported dataset type: {type(dataset)}")
+
+        # Merge splits into one
+        dataset = self.merge_dataset_splits(dataset)
+
+        # Search for text and label columns
+        text_column = self.text_column if self.text_column \
+            else self.find_column("Text column", dataset)
+        label_column = self.label_column if self.label_column \
+            else self.find_column("Label column", dataset)
+
+        # Concat columns for text pairs
         if self.text_pair_column:
-            dataset, text_column = self._merge_textpairs(
-                dataset, text_column, self.text_pair_column
-            )
+            dataset = self.merge_text_pairs(text_column, self.text_pair_column, dataset)
+            text_column = f"{text_column}+{self.text_pair_column}"
 
-        # Convert string labels to integers
-        if isinstance(label_type, str):
-            dataset, self.label_map = self._make_labels_categorical(dataset, label_column)
+        # Assign task category
+        task_category = self.task_type if self.task_type \
+            else self.find_task_category(label_column, dataset)
 
-        # Try to find label map in the dataset
-        if not self.label_map:
-            self.label_map = self._create_label_map(dataset, label_column)
+        # Remove unused columns
+        dataset = dataset.select_columns([text_column, label_column])
 
-        # Remove BIO prefixes for ner or chunking tasks
-        if task_type == "token classification" and self.change_bio_encoding:
-            dataset, self.label_map = self._change_bio_encoding(
-                dataset, label_column, self.label_map
-            )
+        if self.dataset_downsample:
+            dataset = self.downsample(self.dataset_downsample, dataset)
 
-        # Keep only text and label columns
-        keep_columns = {text_column, self.text_pair_column, label_column} - {None}
-        columns_to_remove = list(set(dataset.column_names) - keep_columns)
-        dataset = dataset.remove_columns(columns_to_remove)
+        # Remove empty sentences and unsupported labels
+        if self.cleanup_rows:
+            dataset = self.remove_empty_rows(text_column, label_column, dataset)
 
-        # Set updated attributes and log them
-        self.text_column = text_column
-        self.label_column = label_column
-        self.task_type = task_type
-        self.dataset_size = len(dataset)
-        self.log_dataset_info()
+        # Optional tokenization if texts are not already tokenized
+        if self.tokenize and isinstance(dataset[text_column][0], str):
+            dataset = self.whitespace_tokenize(text_column, dataset)
 
-        return dataset
+        # Create the label map
+        label_map = self.label_map if self.label_map \
+            else self.create_label_map(label_column, dataset)
 
-    def prepare_labels(self, dataset: Dataset) -> torch.Tensor:
-        """Prepare labels as tensors.
-        Flatten labels if they contain lists (for token classification)"""
-        labels = dataset[self.label_column]
-        labels = (
-            [item for sublist in labels for item in sublist]
-            if isinstance(labels[0], list)
-            else labels
-        )
-        return torch.tensor(labels)
+        if isinstance(dataset[label_column][0], str):
+            dataset = self.make_labels_categorical(label_column, label_map, dataset)
+
+        if self.convert_bio_encoding and task_category == TaskCategory.TOKEN_CLASSIFICATION:
+            dataset, label_map = self.remove_bio_encoding(dataset, label_column, label_map)
+
+        self.log_dataset_info(
+            text_column, label_column, label_map, task_category,
+            self.dataset_downsample, dataset_size=len(dataset)
+        )
+
+        dataset = PreprocessedDataset(
+            dataset=dataset,
+            text_column=text_column,
+            label_column=label_column,
+            task_category=task_category,
+        )
 
-    def prepare_sentences(self, dataset: Dataset) -> list[str]:
-        """Gather sentences in the text column."""
-        return dataset[self.text_column]
+        return dataset
 
     @staticmethod
-    def _downsample(dataset: Dataset, ratio: float) -> Dataset:
-        """Reduce the dataset to a chosen ratio."""
-        dataset = dataset.shuffle(seed=42).select(range(int(len(dataset) * ratio)))
+    def merge_dataset_splits(dataset: Union[str, Dataset, DatasetDict]) -> Dataset:
+        if isinstance(dataset, DatasetDict):
+            dataset = datasets.concatenate_datasets(list(dataset.values()))
         return dataset
 
     @staticmethod
-    def _find_text_and_label_columns(
-        dataset: Dataset, text_column: Optional[str] = None, label_column: Optional[str] = None
-    ) -> tuple[str, str, Type]:
-        """Find text and label columns in hf datasets based on common keywords"""
-        text_columns = [
-            "text", "sentence", "token", "tweet", "document", "paragraph", "description",
-            "comment", "utterance", "question", "story", "context", "passage",
-        ]
-
-        label_columns = [
-            "label", "ner_tag", "named_entities", "entities", "tag", "target", "category",
-            "class", "sentiment", "polarity", "emotion", "rating", "stance",
-        ]
-
-        column_names = dataset.column_names
-        if not text_column:
-            # Iterate over keywords and check if it exists in the dataset
-            text_column = next(
-                (col for keyword in text_columns for col in column_names if keyword in col), None
-            )
-        if not label_column:
-            label_column = next(
-                (col for keyword in label_columns for col in column_names if keyword in col), None
-            )
+    def find_column(column_role: str, dataset: Dataset) -> str:
+        """Find text and label columns using common keywords."""
+        common_names: dict = {
+            'Text column': [
+                "text", "sentence", "token", "tweet", "document", "paragraph", "description",
+                "comment", "utterance", "question", "story", "context", "passage",
+            ],
+            "Label column": [
+                "label", "ner_tag", "named_entities", "entities", "tag", "target", "category",
+                "class", "sentiment", "polarity", "emotion", "rating", "stance",
+            ]
+        }
 
-        if not text_column or not label_column:
-            missing = "text" if not text_column else "label"
-            raise KeyError(
-                f'Can not determine the {missing} column. Specify {missing}_column="..." '
-                f"from available columns: {column_names}."
+        columns = dataset.column_names
+        found_column = next(
+            (col for keyword in common_names[column_role] for col in columns if keyword in col),
+            None
+        )
+        if found_column is None:
+            raise ValueError(
+                f"{column_role} not found in dataset: {dataset.column_names}. "
+                f"Specify it manually: text_column: str = ..."
             )
 
-        label_type = type(dataset[label_column][0])
-        return text_column, label_column, label_type
+        return found_column
 
     @staticmethod
-    def _merge_textpairs(
-        dataset: Dataset, text_column: str, text_pair_column: str
-    ) -> tuple[Dataset, str]:
+    def merge_text_pairs(text_column: str, text_pair_column: str, dataset: Dataset) -> Dataset:
         """Concatenate text pairs into a single text using separator token"""
-        new_text_column_name = text_column + "+" + text_pair_column
-
         if text_pair_column not in dataset.column_names:
             raise ValueError(
                 f"Text pair column name '{text_pair_column}' can not be found in the dataset. "
@@ -208,98 +175,86 @@ def merge_texts(dataset_entry: dict[str, str]) -> dict[str, str]:
             dataset_entry[text_column] = (
                 dataset_entry[text_column] + " [SEP] " + dataset_entry[text_pair_column]
             )
-            dataset_entry[new_text_column_name] = dataset_entry.pop(text_column)
             return dataset_entry
 
         dataset = dataset.map(merge_texts, num_proc=None, desc="Merging text pair columns")
-        return dataset, new_text_column_name
+        new_text_column_name = text_column + "+" + text_pair_column
+        dataset = dataset.rename_column(text_column, new_text_column_name)
+        return dataset
 
     @staticmethod
-    def _find_task_type(label_column: str, label_type: type) -> str:
-        """Determine the task type based on the label column's data type."""
-        label_type_to_task_type = {
-            int: "text classification",  # text classification labels can be integers
-            str: "text classification",  # or strings e.g. "positive"
-            list: "token classification",  # token-level tasks have a list of labels
-            float: "text regression",  # regression tasks have continuous values
+    def find_task_category(label_column: str, dataset: Dataset) -> TaskCategory:
+        """Determine task category based on the label column's data type."""
+        label_to_task_type = {
+            int: TaskCategory.TEXT_CLASSIFICATION,  # text classification labels can be integers
+            str: TaskCategory.TEXT_CLASSIFICATION,  # or strings e.g. "positive"
+            list: TaskCategory.TOKEN_CLASSIFICATION,  # token-level tasks have a list of labels
+            float: TaskCategory.TEXT_REGRESSION,  # regression tasks have floats
         }
 
-        for key, task_type in label_type_to_task_type.items():
+        label_type = type(dataset[label_column][0])
+
+        for key, task_type in label_to_task_type.items():
             if issubclass(label_type, key):
                 return task_type
 
         raise ValueError(
-            f"Cannot determine the task type for the label column '{label_column}'. "
-            f"Label types are {list(label_type_to_task_type.keys())}, but got {label_type}."
+            f"Cannot determine task category for the label column '{label_column}'. "
+            f"Label types are {list(label_to_task_type.keys())}, but got {label_type}."
        )
 
     @staticmethod
-    def _tokenize(dataset: Dataset, pre_tokenizer: Whitespace, text_column: str) -> Dataset:
-        """Tokenize a dataset using hf pre-tokenizer (e.g. Whitespace)"""
-
-        def pre_tokenize(example):
-            encoding = pre_tokenizer.pre_tokenize_str(example[text_column])
-            example[text_column] = [token for token, _ in encoding]
-            return example
-
-        dataset = dataset.map(pre_tokenize, num_proc=None, desc="Whitespace pre-tokenization")
-        return dataset
-
-    @staticmethod
-    def _merge_data_splits(dataset: DatasetDict) -> Dataset:
-        """Merge DatasetDict into a single dataset."""
-        return datasets.concatenate_datasets(list(dataset.values()))
-
-    @staticmethod
-    def _remove_empty_rows(dataset: Dataset, text_column: str, label_column: str, is_regression: bool) -> Dataset:
-        """Filter out entries with empty or noisy text or labels."""
-
+    def remove_empty_rows(
+        text_column: str, label_column: str, dataset: Dataset
+    ) -> Dataset:
+        """Filter out entries with empty or noisy texts/labels."""
         def is_valid_entry(sample) -> bool:
             text, label = sample[text_column], sample[label_column]
 
-            # Check if text is non-empty
-            if not text or not label:
+            # Remove empty entries
+            if not text or label is None:
                 return False
 
             if not isinstance(text, list):
                 text = [text]
 
-            # check the text does not contain emoji variation character '\uFE0F'
-            _BAD_CHARACTERS = "\uFE0F"
-
+            # Remove sentences with characters unsupported by most tokenizers
+            _BAD_CHARACTERS = "\uFE0F"  # emoji variation symbol '\uFE0F', etc.
             if any(c in t for t in text for c in _BAD_CHARACTERS):
                 return False
 
-            if not is_regression:
-                # Check that the labels are non-negative
-                if not isinstance(label, list):
-                    label = [label]
+            if not isinstance(label, list):
+                label = [label]
 
-                if any(word_label < 0 for word_label in label):
-                    return False
+            # Remove negative labels from classification datasets
+            if any(isinstance(word_label, int) and word_label < 0 for word_label in label):
+                return False
 
             return True
 
-        return dataset.filter(is_valid_entry, desc="Removing empty rows")
+        dataset = dataset.filter(is_valid_entry, desc="Removing empty rows")
+        dataset = dataset.flatten_indices()
+        return dataset
 
     @staticmethod
-    def _make_labels_categorical(
-        dataset: Dataset, label_column: str
-    ) -> tuple[Dataset, dict[str, int]]:
-        """Convert string labels to integers"""
-        unique_labels = sorted(set(dataset[label_column]))
-        label_map = {label: idx for idx, label in enumerate(unique_labels)}
+    def make_labels_categorical(label_column, label_map, dataset) -> Dataset:
+        """Converts string labels to integers using a label map"""
+        def convert_label(label):
+            """Convert a label (string or list of strings) to its integer representation."""
+            if isinstance(label, list):
+                return [label_map[word_label] for word_label in label]
+            return label_map[label]
 
-        def map_labels(dataset_entry):
-            dataset_entry[label_column] = label_map[dataset_entry[label_column]]
-            return dataset_entry
+        dataset = dataset.map(
+            lambda x: {label_column: convert_label(x[label_column])},
+            desc="Converting labels to categorical"
+        )
 
-        dataset = dataset.map(map_labels, num_proc=None, desc="Mapping string labels to integers")
-        return dataset, label_map
+        return dataset
 
     @staticmethod
-    def _create_label_map(dataset: Dataset, label_column: str) -> dict[str, int]:
-        """Try to find feature names in a hf dataset."""
+    def create_label_map(label_column: str, dataset: Dataset) -> dict[str, int]:
+        """Find feature names to create a label map in a hf dataset."""
         label_names = getattr(
             getattr(dataset.features[label_column], "feature", None), "names", None
         ) or getattr(dataset.features[label_column], "names", None)
 
@@ -314,15 +269,20 @@ def _create_label_map(dataset: Dataset, label_column: str) -> dict[str, int]:
             }
         )
 
-        return {label: idx for idx, label in enumerate(label_names)}
+        label2id = {label: idx for idx, label in enumerate(label_names)}
+        return label2id
 
     @staticmethod
-    def _change_bio_encoding(
+    def downsample(ratio: float, dataset: Dataset) -> Dataset:
+        """Reduce the dataset to a chosen ratio."""
+        dataset = dataset.shuffle(seed=42).select(range(int(len(dataset) * ratio)))
+        return dataset.flatten_indices()
+
+    @staticmethod
+    def remove_bio_encoding(
         dataset: Dataset, label_column: str, label_map: dict[str, int]
     ) -> tuple[Dataset, dict[str, int]]:
-        """Remove BIO prefixes from NER labels, update the dataset, and create a new label map."""
-
-        # Get unique labels without BIO prefixes and create new label map
+        """Remove BIO prefixes for NER labels and create a new label map."""
         unique_labels = set(label.split("-")[-1] for label in label_map)
         new_label_map = {label: idx for idx, label in enumerate(unique_labels)}
 
@@ -334,23 +294,48 @@ def _change_bio_encoding(
             lambda sample: {
                 label_column: [reverse_map[old_idx] for old_idx in sample[label_column]]
             },
-            desc="Removing BIO prefixes",
+            desc="Removing BIO encoding",
         )
 
         # Check if label map was changed
         if label_map == new_label_map:
             logger.warning(
-                "Could not remove BIO prefixes for this tagging dataset. "
-                "Please add the label map as parameter label_map: dict[str, int] = ... manually."
+                "Could not remove BIO prefixes for this tagging dataset. Please add the correct "
+                "label map as parameter label_map: dict[str, int] = ... manually."
             )
 
         return dataset, new_label_map
 
-    def log_dataset_info(self) -> None:
-        """Log information about dataset"""
-        logger.info(f"Texts and labels: {self.text_column}, {self.label_column}")
-        logger.info(f"Label map: {self.label_map}")
-        is_downsampled = self.dataset_downsample and self.dataset_downsample < 1.0
-        downsample_info = f"(down-sampled to {self.dataset_downsample})" if is_downsampled else ""
-        logger.info(f"Dataset size: {self.dataset_size} texts {downsample_info}")
-        logger.info(f"Task category: {self.task_type}")
+    @staticmethod
+    def whitespace_tokenize(text_column: str, dataset: Dataset) -> Dataset:
+        """Tokenize using Whitespace"""
+        tokenizer = Whitespace()
+
+        def pre_tokenize(example):
+            encoding = tokenizer.pre_tokenize_str(example[text_column])
+            example[text_column] = [token for token, _ in encoding]
+            return example
+
+        dataset = dataset.map(pre_tokenize, num_proc=None, desc="Whitespace pre-tokenization")
+        return dataset
+
+    @staticmethod
+    def log_dataset_info(
+        text_column, label_column, label_map, task_category, downsample_ratio, dataset_size
+    ) -> None:
+        """Log information about preprocessed dataset"""
+        # Basic dataset configuration
+        logger.info(
+            f"Dataset Info - Text Column: {text_column}, Label Column: {label_column}, "
+            f"Task Category: {task_category}, Dataset Size: {dataset_size} texts"
+        )
+
+        # Show the down-sampled size
+        if downsample_ratio and downsample_ratio < 1.0:
+            logger.info(
+                f"Dataset has been downsampled to {int(downsample_ratio * 100)}% of original size."
+            )
+
+        # Log the label map
+        if task_category != TaskCategory.TEXT_REGRESSION:
+            logger.info(f"Label Map: {label_map}")
""" - self.data_cleaner = DatasetCleaner( + # Clean the original dataset and keep only needed columns + datacleaner = DatasetCleaner( dataset_downsample=dataset_downsample, text_column=text_column, label_column=label_column, **kwargs, ) - # Prepare dataset, identify task category - self.dataset = self.data_cleaner.prepare_dataset(dataset) - self.task_type = self.data_cleaner.task_type + self.dataset = datacleaner.prepare_dataset(dataset) def run( self, - models: list[Union[str, torch.nn.Module]], + models: List[Union[str, torch.nn.Module]], batch_size: int = 32, estimator: str = "hscore", layer_aggregator: str = "layermean", @@ -54,30 +53,25 @@ def run( **kwargs, ): """ - Load models, get embeddings, score them, and rank results. + Load models, get embeddings, score, and rank results. :param models: A list of model names string identifiers :param batch_size: The number of samples to process in each batch, defaults to 32. - :param estimator: Transferability metric: 'hscore', 'logme', 'knn' - :param layer_aggregator: Which layers to use 'layermean', 'bestlayer' - :param sentence_pooling: Pool words into a sentence embedding for text classification. - :param device: Device for language models ('cpu', 'cuda', 'cuda:2') - :param gpu_estimation: If to score embeddings on the same device (defaults to true) + :param estimator: Transferability metric (e.g., 'hscore', 'logme', 'knn'). + :param layer_aggregator: Which layer to select (e.g., 'layermean', 'bestlayer'). + :param sentence_pooling: Embedder parameter for pooling words into a sentence embedding for + text classification tasks. Defaults to "mean" to average of all words. + :param device: Device for embedding, defaults to GPU if available ('cpu', 'cuda', 'cuda:2'). + :param gpu_estimation: Store and score embeddings on GPU for speedup. :param kwargs: Additional parameters for the embedder class (e.g. 
subword pooling) :return: Returns the sorted dictionary of model names and their scores """ self._confirm_ranker_setup(estimator=estimator, layer_aggregator=layer_aggregator) - # Set device for language model embeddings and log it - device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) - logger.info(f"Running on {device}") - # Load all transformers into hf cache self._preload_transformers(models, device) - # Prepare texts and labels from the dataset - texts = self.data_cleaner.prepare_sentences(self.dataset) - labels = self.data_cleaner.prepare_labels(self.dataset) + labels = self.dataset.labels() ranking_results = Result(metric=estimator) @@ -89,7 +83,7 @@ def run( # Sentence pooling is only applied for text classification tasks effective_sentence_pooling = ( - None if self.task_type == "token classification" else sentence_pooling + None if self.dataset.task_category == "token classification" else sentence_pooling ) embedder = Embedder( @@ -102,14 +96,14 @@ def run( ) embeddings = embedder.embed( - sentences=texts, + self.dataset.texts(), batch_size=batch_size, show_loading_bar=True, move_embeddings_to_cpu=not gpu_estimation, ) # Single list of embeddings for sequence tagging tasks - if self.task_type == "token classification": + if self.dataset.task_category == "token classification": embeddings = [word for sentence in embeddings for word in sentence] model_name = embedder.model_name @@ -150,7 +144,7 @@ def run( zip(embedded_layer_ids, layer_scores) ) - # Layer average gives one score, bestlayer uses max of scores + # Aggregate layer scores final_score = max(layer_scores) if layer_aggregator == "bestlayer" else layer_scores[0] ranking_results.add_score(model_name, final_score) @@ -158,7 +152,7 @@ def run( result_log = f"{model_name} estimation: {final_score} ({ranking_results.metric})" if layer_aggregator == "bestlayer": - result_log += f", layer scores: {ranking_results.layerwise_scores[model_name]}" + result_log += f", layerwise scores: {ranking_results.layerwise_scores[model_name]}" logger.info(result_log) @@ -166,8 +160,7 @@ def run( @staticmethod def _preload_transformers( - models: list[Union[str, torch.nn.Module]], - device: torch.device, + models: List[Union[str, torch.nn.Module]], device: Optional[str] = None ) -> None: """Loads all models into HuggingFace cache""" cached_models, download_models = [], [] @@ -201,14 +194,14 @@ def _confirm_ranker_setup(self, estimator, layer_aggregator) -> None: ) valid_task_types = ["text classification", "token classification", "text regression"] - if self.task_type not in valid_task_types: + if self.dataset.task_category not in valid_task_types: raise ValueError( "Unable to determine task type of the dataset. Please specify it as a parameter: " 'task_type= "text classification", "token classification", or ' '"text regression"' ) - if self.task_type == "text regression" and estimator == "hscore": + if self.dataset.task_category == "text regression" and estimator == "hscore": supported_estimators = [est for est in valid_estimators if est != "hscore"] raise ValueError( f'"{estimator}" does not support text regression. 
' @@ -218,8 +211,8 @@ def _confirm_ranker_setup(self, estimator, layer_aggregator) -> None: def _estimate_score(self, estimator, embeddings: torch.Tensor, labels: torch.Tensor) -> float: """Use an estimator to score a transformer""" estimator_classes = { - "knn": KNN(k=3, regression=(self.task_type == "text regression")), - "logme": LogME(regression=(self.task_type == "text regression")), + "knn": KNN(k=3, regression=(self.dataset.task_category == "text regression")), + "logme": LogME(regression=(self.dataset.task_category == "text regression")), "hscore": HScore(), }
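
Taken together, a hypothetical end-to-end run of the reworked ranker could look as follows (the TransformerRanker import path and the model names are assumptions for illustration, not taken from this diff):

    from datasets import load_dataset
    from transformer_ranker import TransformerRanker  # assumed public export

    dataset = load_dataset("SetFit/rte", trust_remote_code=True)

    # The ranker cleans and down-samples the dataset internally via DatasetCleaner
    ranker = TransformerRanker(dataset, dataset_downsample=0.05)

    # Score candidate models with the default hscore estimator and layer mean
    result = ranker.run(
        models=["prajjwal1/bert-tiny", "google/electra-small-discriminator"],  # illustrative
        batch_size=32,
    )
    print(result)  # sorted model names with transferability scores
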