Skip to content

Commit

Permalink
Merge branch 'lightly-ai:master' into cleanup_lightly/models/modules
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush22iitbhu authored Oct 24, 2024
2 parents 57bb534 + b6955fd commit 328daad
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 23 deletions.
10 changes: 6 additions & 4 deletions benchmarks/imagenet/resnet50/finetune_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Dict

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor, LearningRateMonitor
Expand Down Expand Up @@ -47,7 +48,7 @@ def finetune_eval(
devices: int,
precision: str,
num_classes: int,
) -> None:
) -> Dict[str, float]:
"""Runs fine-tune evaluation on the given model.
Parameters follow SimCLR [0] settings.
Expand Down Expand Up @@ -131,7 +132,8 @@ def finetune_eval(
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
metrics_dict: Dict[str, float] = dict()
for metric in ["val_top1", "val_top5"]:
print_rank_zero(
f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}"
)
print(f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}")
metrics_dict[metric] = max(metric_callback.val_metrics[metric])
return metrics_dict
8 changes: 6 additions & 2 deletions benchmarks/imagenet/resnet50/knn_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Dict

import torch
from pytorch_lightning import LightningModule, Trainer
Expand All @@ -23,7 +24,7 @@ def knn_eval(
accelerator: str,
devices: int,
num_classes: int,
) -> None:
) -> Dict[str, float]:
"""Runs KNN evaluation on the given model.
Parameters follow InstDisc [0] settings.
Expand Down Expand Up @@ -89,5 +90,8 @@ def knn_eval(
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
metrics_dict: dict[str, float] = dict()
for metric in ["val_top1", "val_top5"]:
print_rank_zero(f"knn {metric}: {max(metric_callback.val_metrics[metric])}")
print(f"knn {metric}: {max(metric_callback.val_metrics[metric])}")
metrics_dict[metric] = max(metric_callback.val_metrics[metric])
return metrics_dict
10 changes: 6 additions & 4 deletions benchmarks/imagenet/resnet50/linear_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Dict

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor, LearningRateMonitor
Expand All @@ -24,7 +25,7 @@ def linear_eval(
devices: int,
precision: str,
num_classes: int,
) -> None:
) -> Dict[str, float]:
"""Runs a linear evaluation on the given model.
Parameters follow SimCLR [0] settings.
Expand Down Expand Up @@ -108,7 +109,8 @@ def linear_eval(
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
metrics_dict: Dict[str, float] = dict()
for metric in ["val_top1", "val_top5"]:
print_rank_zero(
f"max linear {metric}: {max(metric_callback.val_metrics[metric])}"
)
print(f"max linear {metric}: {max(metric_callback.val_metrics[metric])}")
metrics_dict[metric] = max(metric_callback.val_metrics[metric])
return metrics_dict
46 changes: 41 additions & 5 deletions benchmarks/imagenet/resnet50/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path
from typing import Sequence, Union
from typing import Dict, Sequence, Union

import barlowtwins
import byol
Expand Down Expand Up @@ -121,11 +121,11 @@ def main(
precision=precision,
ckpt_path=ckpt_path,
)

eval_metrics: Dict[str, Dict[str, float]] = Dict()
if skip_knn_eval:
print_rank_zero("Skipping KNN eval.")
else:
knn_eval.knn_eval(
eval_metrics["knn"] = knn_eval.knn_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
Expand All @@ -140,7 +140,7 @@ def main(
if skip_linear_eval:
print_rank_zero("Skipping linear eval.")
else:
linear_eval.linear_eval(
eval_metrics["linear"] = linear_eval.linear_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
Expand All @@ -156,7 +156,7 @@ def main(
if skip_finetune_eval:
print_rank_zero("Skipping fine-tune eval.")
else:
finetune_eval.finetune_eval(
eval_metrics["finetune"] = finetune_eval.finetune_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
Expand All @@ -169,6 +169,10 @@ def main(
precision=precision,
)

if eval_metrics:
print(f"Results for {method}:")
print(eval_metrics_to_markdown(eval_metrics))


def pretrain(
model: LightningModule,
Expand Down Expand Up @@ -246,6 +250,38 @@ def pretrain(
print_rank_zero(f"max {metric}: {max(metric_callback.val_metrics[metric])}")


def eval_metrics_to_markdown(metrics: Dict[str, Dict[str, float]]) -> str:
EVAL_NAME_COLUMN_NAME = "Eval Name"
METRIC_COLUMN_NAME = "Metric Name"
VALUE_COLUMN_NAME = "Value"

eval_name_max_len = max(
len(eval_name) for eval_name in list(metrics.keys()) + [EVAL_NAME_COLUMN_NAME]
)
metric_name_max_len = max(
len(metric_name)
for metric_dict in metrics.values()
for metric_name in list(metric_dict.keys()) + [METRIC_COLUMN_NAME]
)
value_max_len = max(
len(metric_value)
for metric_dict in metrics.values()
for metric_value in list(f"{value:.2f}" for value in metric_dict.values())
+ [VALUE_COLUMN_NAME]
)

header = f"| {EVAL_NAME_COLUMN_NAME.ljust(eval_name_max_len)} | {METRIC_COLUMN_NAME.ljust(metric_name_max_len)} | {VALUE_COLUMN_NAME.ljust(value_max_len)} |"
separator = f"|:{'-' * (eval_name_max_len)}:|:{'-' * (metric_name_max_len)}:|:{'-' * (value_max_len)}:|"

lines = [header, separator] + [
f"| {eval_name.ljust(eval_name_max_len)} | {metric_name.ljust(metric_name_max_len)} | {f'{metric_value:.2f}'.ljust(value_max_len)} |"
for eval_name, metric_dict in metrics.items()
for metric_name, metric_value in metric_dict.items()
]

return "\n".join(lines)


if __name__ == "__main__":
args = parser.parse_args()
main(**vars(args))
5 changes: 4 additions & 1 deletion benchmarks/imagenet/vitb16/finetune_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def finetune_eval(
devices: int,
precision: str,
num_classes: int,
) -> None:
) -> Dict[str, float]:
"""Runs fine-tune evaluation on the given model.
Parameters follow MAE settings.
Expand Down Expand Up @@ -211,5 +211,8 @@ def finetune_eval(
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
metrics_dict: Dict[str, float] = dict()
for metric in ["val_top1", "val_top5"]:
print(f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}")
metrics_dict[metric] = max(metric_callback.val_metrics[metric])
return metrics_dict
6 changes: 5 additions & 1 deletion benchmarks/imagenet/vitb16/knn_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Dict

import torch
from pytorch_lightning import LightningModule, Trainer
Expand All @@ -22,7 +23,7 @@ def knn_eval(
accelerator: str,
devices: int,
num_classes: int,
) -> None:
) -> Dict[str, float]:
print("Running KNN evaluation...")

# Setup training data.
Expand Down Expand Up @@ -76,5 +77,8 @@ def knn_eval(
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
metrics_dict: dict[str, float] = dict()
for metric in ["val_top1", "val_top5"]:
print(f"knn {metric}: {max(metric_callback.val_metrics[metric])}")
metrics_dict[metric] = max(metric_callback.val_metrics[metric])
return metrics_dict
7 changes: 5 additions & 2 deletions benchmarks/imagenet/vitb16/linear_eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Tuple
from typing import Dict, Tuple

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor, LearningRateMonitor
Expand Down Expand Up @@ -74,7 +74,7 @@ def linear_eval(
devices: int,
precision: str,
num_classes: int,
) -> None:
) -> Dict[str, float]:
"""Runs a linear evaluation on the given model.
Parameters follow MAE settings.
Expand Down Expand Up @@ -145,5 +145,8 @@ def linear_eval(
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
metrics_dict: Dict[str, float] = dict()
for metric in ["val_top1", "val_top5"]:
print(f"max linear {metric}: {max(metric_callback.val_metrics[metric])}")
metrics_dict[metric] = max(metric_callback.val_metrics[metric])
return metrics_dict
45 changes: 41 additions & 4 deletions benchmarks/imagenet/vitb16/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path
from typing import Sequence, Union
from typing import Dict, Sequence, Union

import aim
import finetune_eval
Expand Down Expand Up @@ -100,10 +100,11 @@ def main(
strategy=strategy,
)

eval_metrics: Dict[str, Dict[str, float]] = Dict()
if skip_knn_eval:
print("Skipping KNN eval.")
else:
knn_eval.knn_eval(
eval_metrics["knn"] = knn_eval.knn_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
Expand All @@ -118,7 +119,7 @@ def main(
if skip_linear_eval:
print("Skipping linear eval.")
else:
linear_eval.linear_eval(
eval_metrics["linear"] = linear_eval.linear_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
Expand All @@ -134,7 +135,7 @@ def main(
if skip_finetune_eval:
print("Skipping fine-tune eval.")
else:
finetune_eval.finetune_eval(
eval_metrics["finetune"] = finetune_eval.finetune_eval(
model=model,
num_classes=num_classes,
train_dir=train_dir,
Expand All @@ -147,6 +148,10 @@ def main(
precision=precision,
)

if eval_metrics:
print(f"Results for {method}:")
print(eval_metrics_to_markdown(eval_metrics))


def pretrain(
model: LightningModule,
Expand Down Expand Up @@ -218,6 +223,38 @@ def pretrain(
print(f"max {metric}: {max(metric_callback.val_metrics[metric])}")


def eval_metrics_to_markdown(metrics: Dict[str, Dict[str, float]]) -> str:
EVAL_NAME_COLUMN_NAME = "Eval Name"
METRIC_COLUMN_NAME = "Metric Name"
VALUE_COLUMN_NAME = "Value"

eval_name_max_len = max(
len(eval_name) for eval_name in list(metrics.keys()) + [EVAL_NAME_COLUMN_NAME]
)
metric_name_max_len = max(
len(metric_name)
for metric_dict in metrics.values()
for metric_name in list(metric_dict.keys()) + [METRIC_COLUMN_NAME]
)
value_max_len = max(
len(metric_value)
for metric_dict in metrics.values()
for metric_value in list(f"{value:.2f}" for value in metric_dict.values())
+ [VALUE_COLUMN_NAME]
)

header = f"| {EVAL_NAME_COLUMN_NAME.ljust(eval_name_max_len)} | {METRIC_COLUMN_NAME.ljust(metric_name_max_len)} | {VALUE_COLUMN_NAME.ljust(value_max_len)} |"
separator = f"|:{'-' * (eval_name_max_len)}:|:{'-' * (metric_name_max_len)}:|:{'-' * (value_max_len)}:|"

lines = [header, separator] + [
f"| {eval_name.ljust(eval_name_max_len)} | {metric_name.ljust(metric_name_max_len)} | {f'{metric_value:.2f}'.ljust(value_max_len)} |"
for eval_name, metric_dict in metrics.items()
for metric_name, metric_value in metric_dict.items()
]

return "\n".join(lines)


if __name__ == "__main__":
args = parser.parse_args()
main(**vars(args))

0 comments on commit 328daad

Please sign in to comment.