Commit

Add abstract class for transferability metrics
lukasgarbas committed Dec 1, 2024
1 parent 7edf0e4 commit 654ea83
Showing 7 changed files with 113 additions and 95 deletions.
6 changes: 3 additions & 3 deletions tests/estimators/test_knn.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from transformer_ranker.estimators import KNN
+from transformer_ranker.estimators import NearestNeighbors
 
 
 def sample_data(
@@ -40,7 +40,7 @@ def sample_data(
 @pytest.mark.parametrize("k,dim", [(6, 1024), (10, 100), (100, 256), (1024, 16)])
 def test_knn_on_constructed_data(k, dim):
     features, labels, expected_accuracy = sample_data(k=k, dim=dim)
-    estimator = KNN(k)
+    estimator = NearestNeighbors(k=k)
 
     accuracy = estimator.fit(features, labels)
 
@@ -57,6 +57,6 @@ def test_knn_on_constructed_data(k, dim):
     ],
 )
 def test_knn_on_iris(iris_dataset, k, expected_accuracy):
-    e = KNN(k)
+    e = NearestNeighbors(k=k)
     score = e.fit(iris_dataset["data"], iris_dataset["target"])
     assert score == pytest.approx(expected_accuracy)
4 changes: 3 additions & 1 deletion transformer_ranker/estimators/__init__.py
@@ -1,3 +1,5 @@
 from .hscore import HScore
 from .logme import LogME
-from .nearestneighbors import KNN
+from .nearestneighbors import NearestNeighbors
+
+__all__ = ["HScore", "LogME", "NearestNeighbors"]
20 changes: 20 additions & 0 deletions transformer_ranker/estimators/base.py
@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+
+class Estimator(ABC):
+    """Abstract base class for transferability metrics."""
+
+    def __init__(self, regression: bool, **kwargs):
+        self.regression: bool = regression
+        self.score: Optional[float] = None
+
+    @abstractmethod
+    def fit(self, *, embeddings: torch.tensor, labels: torch.tensor, **kwargs) -> float:
+        """Compute score given embeddings and labels.
+        :param embeddings: Embedding tensor (num_samples, num_features)
+        :param labels: label tensor (num_samples,)
+        """
+        pass
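
The contract set by this base class is easiest to see with a concrete subclass. Below is a minimal, hypothetical sketch; the MeanDistance class and its scoring rule are invented for illustration and are not part of the commit.

    import torch

    from transformer_ranker.estimators.base import Estimator


    class MeanDistance(Estimator):
        """Toy metric: mean pairwise distance between class-mean embeddings (illustrative only)."""

        def __init__(self, regression: bool = False):
            super().__init__(regression=regression)

        def fit(self, *, embeddings: torch.Tensor, labels: torch.Tensor, **kwargs) -> float:
            class_means = torch.stack(
                [embeddings[labels == c].mean(dim=0) for c in torch.unique(labels)]
            )
            self.score = torch.cdist(class_means, class_means).mean().item()
            return self.score
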
23 changes: 15 additions & 8 deletions transformer_ranker/estimators/hscore.py
@@ -1,23 +1,30 @@
+import warnings
+
 import torch
 
+from .base import Estimator
+
 
-class HScore:
-    def __init__(self):
+class HScore(Estimator):
+    def __init__(self, regression: bool = False):
         """
         Regularized H-Score estimator.
-        Original H-score paper: https://arxiv.org/abs/2212.10082
+        Paper: https://arxiv.org/abs/2212.10082
         Shrinkage-based (regularized) H-Score: https://openreview.net/pdf?id=iz_Wwmfquno
         """
-        self.score = None
+        if regression:
+            warnings.warn("HScore is not suitable for regression tasks.", UserWarning)
+
+        super().__init__(regression=regression)
 
     def fit(self, embeddings: torch.Tensor, labels: torch.Tensor) -> float:
         """
-        H-score intuition: Higher variance between embeddings of different classes
+        H-score intuition: higher variance between embeddings of different classes
         (mean vectors for each class) and lower feature redundancy (i.e. inverse of the covariance
         matrix for all data points) lead to better transferability.
-        :param embeddings: Embedding matrix of shape (num_samples, hidden_size)
-        :param labels: Label vector of shape (num_samples,)
+        :param embeddings: Embedding tensor (num_samples, hidden_size)
+        :param labels: Label tensor (num_samples,)
         :return: H-score, where higher is better.
         """
         # Center all embeddings
@@ -26,7 +33,7 @@ def fit(self, embeddings: torch.Tensor, labels: torch.Tensor) -> float:
 
         # Number of samples, hidden size (i.e. embedding length), number of classes
         num_samples, hidden_size = embeddings.size()
-        classes, class_counts = torch.unique(labels, return_counts=True)
+        classes, _ = torch.unique(labels, return_counts=True)
         num_classes = len(classes)
 
         # Feature covariance matrix (hidden_size x hidden_size)
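
The remainder of fit() is truncated above. To make the docstring's intuition concrete, here is a sketch of the plain (unregularized) H-score; it is a simplified illustration, not the repository's shrinkage-based implementation, and the hscore_sketch name is invented.

    import torch

    def hscore_sketch(embeddings: torch.Tensor, labels: torch.Tensor) -> float:
        embeddings = embeddings - embeddings.mean(dim=0)        # center all embeddings
        num_samples, _ = embeddings.size()

        # Covariance of all features (hidden_size x hidden_size)
        cov_features = embeddings.T @ embeddings / num_samples

        # Replace each embedding with its class mean, then take the covariance of
        # these class-mean vectors (inter-class variance, weighted by class size)
        class_means = torch.zeros_like(embeddings)
        for c in torch.unique(labels):
            class_means[labels == c] = embeddings[labels == c].mean(dim=0)
        cov_between = class_means.T @ class_means / num_samples

        # Higher inter-class variance relative to feature covariance -> higher score
        return torch.trace(torch.linalg.pinv(cov_features) @ cov_between).item()
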
16 changes: 7 additions & 9 deletions transformer_ranker/estimators/logme.py
@@ -1,17 +1,15 @@
-from typing import Optional
 import torch
 
+from .base import Estimator
+
 
-class LogME:
+class LogME(Estimator):
     def __init__(self, regression: bool = False):
         """
         LogME (Log of Maximum Evidence) estimator.
         Paper: https://arxiv.org/abs/2102.11005
         :param regression: Boolean flag if the task is regression.
         """
-        self.regression = regression
-        self.score: Optional[float] = None
+        super().__init__(regression=regression)
 
     def fit(
         self,
@@ -27,8 +25,8 @@ def fit(
         the prior (alpha) and likelihood (beta), projecting the target labels onto the singular
         vectors of the feature matrix.
-        :param embeddings: Embedding matrix of shape (num_samples, hidden_dim)
-        :param labels: Label vector of shape (num_samples,)
+        :param embeddings: Embedding tensor (num_samples, hidden_dim)
+        :param labels: Label tensor (num_samples,)
         :param initial_alpha: Initial precision of the prior (controls the regularization strength)
         :param initial_beta: Initial precision of the likelihood (controls the noise in the data)
         :param tol: Tolerance for the optimization convergence
@@ -44,7 +42,7 @@ def fit(
 
         # Get the number of samples, number of classes, and the hidden size
        num_samples, hidden_size = embeddings.shape
-        class_names, counts = torch.unique(labels, return_counts=True)
+        class_names, _ = torch.unique(labels, return_counts=True)
         num_classes = labels.shape[1] if self.regression else len(class_names)
 
         # SVD on the features
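
The iterative optimization in fit() is truncated above. The projection step the docstring refers to looks roughly like the sketch below; this is an illustration of the idea, not the repository's code, and it assumes classification with integer labels.

    import torch

    # One-hot encode the labels and project them onto the left singular vectors of the
    # embedding matrix; LogME then iterates the alpha/beta updates per target column
    # until the log-evidence converges, and averages the evidence over columns.
    u, s, _ = torch.linalg.svd(embeddings, full_matrices=False)   # embeddings: (N, D)
    sigma = s ** 2                                                 # squared singular values
    targets = torch.nn.functional.one_hot(labels).float()         # (N, num_classes)
    z = u.T @ targets                                              # projected targets: (min(N, D), num_classes)
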
44 changes: 31 additions & 13 deletions transformer_ranker/estimators/nearestneighbors.py
@@ -1,43 +1,61 @@
-from typing import Union, Optional
+from typing import Optional, Union
 
 import torch
 from torchmetrics.classification import BinaryF1Score, MulticlassF1Score
+from torch.nn.functional import cosine_similarity
 
+from .base import Estimator
+
 
-class KNN:
+class NearestNeighbors(Estimator):
     def __init__(
         self,
-        k: int = 3,
         regression: bool = False,
+        k: int = 3,
     ):
         """
         K-Nearest Neighbors estimator.
         :param k: Number of nearest neighbors to consider.
         :param regression: Boolean flag if the task is regression.
         """
-        self.k = k
-        self.regression = regression
-        self.score: Optional[float] = None
+        super().__init__(regression=regression)
+
+        self.k = k  # number of neighbors
+        self.distance_metrics = {
+            'euclidean': lambda x, y: torch.cdist(x, y, p=2),
+            'cosine': lambda x, y: 1 - cosine_similarity(x[:, None, :], y[None, :, :], dim=-1)
+        }
 
-    def fit(self, embeddings: torch.Tensor, labels: torch.Tensor, batch_size: int = 1024) -> float:
+    def fit(
+        self,
+        embeddings: torch.Tensor,
+        labels: torch.Tensor,
+        batch_size: int = 1024,
+        distance_metric: str = 'euclidean',
+    ) -> float:
         """
-        Estimate embedding suitability for classification or regression using nearest neighbors
+        Evaluate embeddings using kNN. Distance and topk computations are done in batches.
-        :param embeddings: Embedding matrix of shape (n_samples, hidden_size)
-        :param labels: Label vector of shape (n_samples,)
+        :param embeddings: Embedding tensor (n_samples, hidden_size)
+        :param labels: Label tensor (n_samples,)
         :param batch_size: Batch size for distance and top-k computation in chunks
-        :return: Score (F1 score for classification or Pearson correlation for regression)
+        :param distance_metric: Metric to use for distance computation 'euclidean', 'cosine'
+        :return: F1-micro score (for classification) or Pearson correlation (for regression)
         """
         num_samples = embeddings.size(0)
         num_classes = len(torch.unique(labels))
         knn_indices = torch.zeros((num_samples, self.k), dtype=torch.long, device=embeddings.device)
 
+        distance_func = self.distance_metrics.get(distance_metric)
+
         for start in range(0, num_samples, batch_size):
             end = min(start + batch_size, num_samples)
             batch_features = embeddings[start:end]
 
-            # Euclidean distances between the batch and all other features
-            dists = torch.cdist(batch_features, embeddings, p=2)
+            # Distances between the batch and all other features
+            dists = distance_func(batch_features, embeddings)
 
             # Exclude self-distances by setting diagonal to a large number
             diag_indices = torch.arange(start, end, device=embeddings.device)
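
The tail of fit() is truncated above. Below is a hedged sketch of how the batched loop and the final scoring might conclude for classification, using only names visible in this diff (knn_indices, diag_indices, num_classes, MulticlassF1Score); the actual implementation may differ.

    # Inside the batch loop: mask self-distances, then keep the k nearest neighbors.
    dists[torch.arange(end - start), diag_indices] = float("inf")
    knn_indices[start:end] = dists.topk(self.k, largest=False).indices

    # After the loop: majority vote over neighbor labels, scored with micro-averaged F1.
    neighbor_labels = labels[knn_indices]                  # (num_samples, k)
    predictions = neighbor_labels.mode(dim=1).values
    f1 = MulticlassF1Score(num_classes=num_classes, average="micro")
    self.score = f1(predictions, labels).item()
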