Skip to content

Commit

Permalink
Add typing
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorSusmelj committed Nov 19, 2023
1 parent 9f2ed93 commit 8f43d6b
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 21 deletions.
17 changes: 10 additions & 7 deletions lightly/utils/benchmarking/linear_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Tuple
from typing import Any, Dict, List, Tuple

from pytorch_lightning import LightningModule
from torch import Tensor
Expand Down Expand Up @@ -94,17 +94,20 @@ def __init__(

def forward(self, images: Tensor) -> Tensor:
features = self.model.forward(images).flatten(start_dim=1)
return self.classification_head(features)
output: Tensor = self.classification_head(features)
return output

def shared_step(self, batch, batch_idx) -> Tuple[Tensor, Dict[int, Tensor]]:
def shared_step(
self, batch: Tuple[Tensor, Tensor], batch_idx: int
) -> Tuple[Tensor, Dict[int, Tensor]]:
images, targets = batch[0], batch[1]
predictions = self.forward(images)
loss = self.criterion(predictions, targets)
_, predicted_labels = predictions.topk(max(self.topk))
topk = mean_topk_accuracy(predicted_labels, targets, k=self.topk)
return loss, topk

def training_step(self, batch, batch_idx) -> Tensor:
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
loss, topk = self.shared_step(batch=batch, batch_idx=batch_idx)
batch_size = len(batch[1])
log_dict = {f"train_top{k}": acc for k, acc in topk.items()}
Expand All @@ -114,15 +117,15 @@ def training_step(self, batch, batch_idx) -> Tensor:
self.log_dict(log_dict, sync_dist=True, batch_size=batch_size)
return loss

def validation_step(self, batch, batch_idx) -> Tensor:
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
loss, topk = self.shared_step(batch=batch, batch_idx=batch_idx)
batch_size = len(batch[1])
log_dict = {f"val_top{k}": acc for k, acc in topk.items()}
self.log("val_loss", loss, prog_bar=True, sync_dist=True, batch_size=batch_size)
self.log_dict(log_dict, prog_bar=True, sync_dist=True, batch_size=batch_size)
return loss

def configure_optimizers(self):
def configure_optimizers(self) -> Tuple[List[Any], List[Dict[str, Any]]]:
parameters = list(self.classification_head.parameters())
if not self.freeze_model:
parameters += self.model.parameters()
Expand All @@ -136,7 +139,7 @@ def configure_optimizers(self):
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=0,
max_epochs=self.trainer.estimated_stepping_batches,
max_epochs=int(self.trainer.estimated_stepping_batches),
),
"interval": "step",
}
Expand Down
2 changes: 1 addition & 1 deletion lightly/utils/benchmarking/metric_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class MetricCallback(Callback):
>>> max_val_acc = max(metric_callback.val_metrics["val_acc"])
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()
self.train_metrics: Dict[str, List[float]] = {}
self.val_metrics: Dict[str, List[float]] = {}
Expand Down
15 changes: 9 additions & 6 deletions lightly/utils/bounding_box.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" Bounding Box Utils """

from __future__ import annotations

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
Expand Down Expand Up @@ -43,7 +44,7 @@ def __init__(

if clip_values:

def clip_to_0_1(value):
def clip_to_0_1(value: float) -> float:
return min(1, max(0, value))

x0 = clip_to_0_1(x0)
Expand Down Expand Up @@ -75,7 +76,7 @@ def clip_to_0_1(value):
self.y1 = y1

@classmethod
def from_x_y_w_h(cls, x: float, y: float, w: float, h: float):
def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox:
"""Helper to convert from bounding box format with width and height.
Examples:
Expand All @@ -85,7 +86,9 @@ def from_x_y_w_h(cls, x: float, y: float, w: float, h: float):
return cls(x, y, x + w, y + h)

@classmethod
def from_yolo_label(cls, x_center: float, y_center: float, w: float, h: float):
def from_yolo_label(
cls, x_center: float, y_center: float, w: float, h: float
) -> BoundingBox:
"""Helper to convert from yolo label format
x_center, y_center, w, h --> x0, y0, x1, y1
Expand All @@ -102,16 +105,16 @@ def from_yolo_label(cls, x_center: float, y_center: float, w: float, h: float):
)

@property
def width(self):
def width(self) -> float:
"""Returns the width of the bounding box relative to the image size."""
return self.x1 - self.x0

@property
def height(self):
def height(self) -> float:
"""Returns the height of the bounding box relative to the image size."""
return self.y1 - self.y0

@property
def area(self):
def area(self) -> float:
"""Returns the area of the bounding box relative to the area of the image."""
return self.width * self.height
28 changes: 21 additions & 7 deletions lightly/utils/embeddings_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from __future__ import annotations

from typing import Optional, Tuple

import numpy as np
from numpy.typing import NDArray


class PCA(object):
Expand All @@ -18,11 +23,11 @@ class PCA(object):

def __init__(self, n_components: int = 2, eps: float = 1e-10):
self.n_components = n_components
self.mean = None
self.w = None
self.mean: Optional[NDArray[np.float32]] = None
self.w: Optional[NDArray[np.float32]] = None
self.eps = eps

def fit(self, X: np.ndarray):
def fit(self, X: NDArray[np.float32]) -> PCA:
"""Fits PCA to data in X.
Args:
Expand All @@ -35,15 +40,15 @@ def fit(self, X: np.ndarray):
"""
X = X.astype(np.float32)
self.mean = X.mean(axis=0)
X = X - self.mean + self.eps
X = X - self.mean + self.eps # type: ignore
cov = np.cov(X.T) / X.shape[0]
v, w = np.linalg.eig(cov)
idx = v.argsort()[::-1]
v, w = v[idx], w[:, idx]
self.w = w
return self

def transform(self, X: np.ndarray):
def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]:
"""Uses PCA to transform data in X.
Args:
Expand All @@ -53,13 +58,22 @@ def transform(self, X: np.ndarray):
Returns:
Numpy array of n x p datapoints where p <= d.
Raises:
ValueError: If PCA was not fitted before.
"""
if self.mean is None or self.w is None:
raise ValueError("PCA not fitted yet. Call fit() before transform().")
X = X.astype(np.float32)
X = X - self.mean + self.eps
return X.dot(self.w)[:, : self.n_components]
transformed = X.dot(self.w)[:, : self.n_components]
return np.asarray(transformed, dtype=np.float32)


def fit_pca(embeddings: np.ndarray, n_components: int = 2, fraction: float = None):
def fit_pca(
embeddings: NDArray[np.float32],
n_components: int = 2,
fraction: Optional[float] = None,
) -> PCA:
"""Fits PCA to randomly selected subset of embeddings.
For large datasets, it can be unfeasible to perform PCA on the whole data.
Expand Down

0 comments on commit 8f43d6b

Please sign in to comment.