From 79e28ab0afabfbeb1a889353ea7624be185336ff Mon Sep 17 00:00:00 2001
From: clemsgrs
Date: Fri, 15 Mar 2024 12:11:20 +0100
Subject: [PATCH] fixed bug in knn evaluation: probabilities didn't sum to 1

---
 dinov2/eval/knn.py                   | 179 ++++++++++++++-------------
 dinov2/eval/setup.py                 |  10 +-
 dinov2/eval/utils.py                 |   3 +-
 dinov2/inference/extract_features.py |   2 +-
 4 files changed, 99 insertions(+), 95 deletions(-)

diff --git a/dinov2/eval/knn.py b/dinov2/eval/knn.py
index 42df119b4..eaddd4707 100644
--- a/dinov2/eval/knn.py
+++ b/dinov2/eval/knn.py
@@ -3,95 +3,52 @@
 # This source code is licensed under the Apache License, Version 2.0
 # found in the LICENSE file in the root directory of this source tree.
 
+import os
+import datetime
 import argparse
 from functools import partial
 import json
 import logging
-import sys
 from pathlib import Path
-from typing import List, Optional
+from typing import Optional
 
 import torch
 from torch.nn.functional import one_hot, softmax
 
 import dinov2.distributed as distributed
-from dinov2.data import SamplerType, make_data_loader
+from dinov2.data import SamplerType, make_data_loader, make_dataset
 from dinov2.eval.metrics import AccuracyAveraging, build_metric
-from dinov2.eval.setup import get_args_parser as get_setup_args_parser
+from dinov2.utils.utils import initialize_wandb
+from dinov2.utils.config import setup, write_config
 from dinov2.eval.setup import setup_and_build_model
 from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
+from dinov2.data.transforms import make_classification_eval_transform
 
 
 logger = logging.getLogger("dinov2")
 
 
-def get_args_parser(
-    description: Optional[str] = None,
-    parents: Optional[List[argparse.ArgumentParser]] = None,
-    add_help: bool = True,
-):
-    parents = parents or []
-    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
-    parents = [setup_args_parser]
-    parser = argparse.ArgumentParser(
-        description=description,
-        parents=parents,
-        add_help=add_help,
-    )
+def get_args_parser(add_help: bool = True):
+    parser = argparse.ArgumentParser("DINOv2 k-NN evaluation", add_help=add_help)
+    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
     parser.add_argument(
-        "--query-dataset",
-        dest="query_dataset_str",
-        type=str,
-        help="Query dataset",
+        "opts",
+        help="""
+Modify config options at the end of the command. For Yacs configs, use
+space-separated "PATH.KEY VALUE" pairs.
+For python-based LazyConfig, use "path.key=value".
+        """.strip(),
+        default=None,
+        nargs=argparse.REMAINDER,
     )
     parser.add_argument(
-        "--test-dataset",
-        dest="test_dataset_str",
+        "--output-dir",
+        "--output_dir",
+        default="./output",
         type=str,
-        help="Test dataset",
-    )
-    parser.add_argument(
-        "--nb_knn",
-        nargs="+",
-        type=int,
-        help="Number of NN to use. 20 is usually working the best.",
-    )
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        help="Temperature used in the voting coefficient",
-    )
-    parser.add_argument(
-        "--gather-on-cpu",
-        action="store_true",
-        help="Whether to gather the query features on cpu, slower"
-        "but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
-    )
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        help="Batch size.",
-    )
-    parser.add_argument(
-        "--n-per-class-list",
-        nargs="+",
-        type=int,
-        help="Number to take per class",
-    )
-    parser.add_argument(
-        "--n-tries",
-        type=int,
-        help="Number of tries",
-    )
-    parser.set_defaults(
-        query_dataset_str="ImageNet:split=QUERY",
-        test_dataset_str="ImageNet:split=TEST",
-        nb_knn=[10, 20, 100, 200],
-        temperature=0.07,
-        batch_size=256,
-        n_per_class_list=[-1],
-        n_tries=1,
+        help="Output directory to save logs and checkpoints",
     )
+
     return parser
 
 
@@ -191,7 +148,8 @@ def __init__(self, keys):
     def forward(self, features_dict, targets):
         for k in self.keys:
             features_dict = features_dict[k]
-        return {"preds": features_dict, "target": targets}
+        preds = features_dict / features_dict.sum(dim=-1).unsqueeze(-1)
+        return {"preds": preds, "target": targets}
 
 
 def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, query_features, query_labels):
@@ -271,8 +229,9 @@ def eval_knn(
         header="Query",
         verbose=verbose,
     )
+    # given the model went through ModelWithNormalize, query_features are already normalized
    if verbose:
-        logger.info(f"Query features created, shape {query_features.shape}.")
+        logger.info(f"Query features created, shape {tuple(query_features.shape)}.")
 
     test_dataloader = make_data_loader(
         dataset=test_dataset,
@@ -376,9 +335,9 @@ def eval_knn_with_model(
             results_dict[f"{k} Accuracy"] = acc
             results_dict[f"{k} AUC"] = auc
             if model_name and verbose:
-                logger.info(f"{model_name.title()} | {k}-NN classifier result: Accuracy: {acc:.2f} | AUC: {auc:.2f}")
+                logger.info(f"{model_name.title()} | {k}-NN classifier result: Accuracy: {acc:.2f} | AUC: {auc:.5f}")
             elif verbose:
-                logger.info(f"{k}-NN classifier result: Accuracy: {acc:.2f} | AUC: {auc:.2f}")
+                logger.info(f"{k}-NN classifier result: Accuracy: {acc:.2f} | AUC: {auc:.5f}")
 
     metrics_file_path = Path(output_dir, "results_eval_knn.json")
     with open(metrics_file_path, "a") as f:
@@ -394,29 +353,77 @@ def eval_knn_with_model(
 
 
 def main(args):
-    model, autocast_dtype = setup_and_build_model(args)
+    cfg = setup(args)
+
+    run_distributed = torch.cuda.device_count() > 1
+    if run_distributed:
+        gpu_id = int(os.environ["LOCAL_RANK"])
+    else:
+        gpu_id = -1
+
+    if distributed.is_main_process():
+        print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")
+        run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
+        # set up wandb
+        if cfg.wandb.enable:
+            key = os.environ.get("WANDB_API_KEY")
+            wandb_run = initialize_wandb(cfg, key=key)
+            run_id = wandb_run.id
+    else:
+        run_id = ""
+
+    if run_distributed:
+        obj = [run_id]
+        torch.distributed.broadcast_object_list(obj, 0, device=torch.device(f"cuda:{gpu_id}"))
+        run_id = obj[0]
+
+    output_dir = Path(cfg.train.output_dir, run_id)
+    if distributed.is_main_process():
+        output_dir.mkdir(exist_ok=True, parents=True)
+    cfg.train.output_dir = str(output_dir)
+
+    if distributed.is_main_process():
+        write_config(cfg, cfg.train.output_dir)
+
+    model, autocast_dtype = setup_and_build_model(cfg)
+
+    transform = make_classification_eval_transform()
+    query_dataset_str = cfg.data.query_dataset
+    test_dataset_str = cfg.data.test_dataset
+    query_dataset = make_dataset(
+        dataset_str=query_dataset_str,
+        transform=transform,
+    )
+    test_dataset = make_dataset(
+        dataset_str=test_dataset_str,
+        transform=transform,
+    )
+
     eval_knn_with_model(
         model=model,
-        output_dir=args.output_dir,
-        query_dataset_str=args.query_dataset_str,
-        test_dataset_str=args.test_dataset_str,
-        nb_knn=args.nb_knn,
-        temperature=args.temperature,
+        output_dir=cfg.train.output_dir,
+        query_dataset=query_dataset,
+        test_dataset=test_dataset,
+        nb_knn=cfg.knn.nb_knn,
+        temperature=cfg.knn.temperature,
         autocast_dtype=autocast_dtype,
         accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
         gpu_id=-1,
-        transform=None,
-        gather_on_cpu=args.gather_on_cpu,
-        batch_size=args.batch_size,
-        num_workers=5,
-        n_per_class_list=args.n_per_class_list,
-        n_tries=args.n_tries,
+        gather_on_cpu=cfg.speed.gather_on_cpu,
+        batch_size=cfg.data.batch_size,
+        num_workers=cfg.speed.num_workers,
+        n_per_class_list=cfg.knn.n_per_class_list,
+        n_tries=cfg.knn.n_tries,
+        verbose=True,
     )
+
     return 0
 
 
 if __name__ == "__main__":
-    description = "DINOv2 k-NN evaluation"
-    args_parser = get_args_parser(description=description)
-    args = args_parser.parse_args()
-    sys.exit(main(args))
+    import warnings
+
+    warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
+
+    args = get_args_parser(add_help=True).parse_args()
+    main(args)
diff --git a/dinov2/eval/setup.py b/dinov2/eval/setup.py
index 959128c06..de162d43b 100644
--- a/dinov2/eval/setup.py
+++ b/dinov2/eval/setup.py
@@ -10,7 +10,6 @@
 import torch.backends.cudnn as cudnn
 
 from dinov2.models import build_model_from_cfg
-from dinov2.utils.config import setup
 import dinov2.utils.utils as dinov2_utils
 
 
@@ -59,17 +58,16 @@ def get_autocast_dtype(config):
     return torch.float
 
 
-def build_model_for_eval(config, pretrained_weights):
+def build_model_for_eval(config):
     model, _ = build_model_from_cfg(config, only_teacher=True)
-    dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher")
+    dinov2_utils.load_pretrained_weights(model, config.student.pretrained_weights, "teacher")
     model.eval()
     model.cuda()
     return model
 
 
-def setup_and_build_model(args) -> Tuple[Any, torch.dtype]:
+def setup_and_build_model(config) -> Tuple[Any, torch.dtype]:
     cudnn.benchmark = True
-    config = setup(args)
-    model = build_model_for_eval(config, args.pretrained_weights)
+    model = build_model_for_eval(config)
     autocast_dtype = get_autocast_dtype(config)
     return model, autocast_dtype
diff --git a/dinov2/eval/utils.py b/dinov2/eval/utils.py
index 3725d97a8..2e6138136 100644
--- a/dinov2/eval/utils.py
+++ b/dinov2/eval/utils.py
@@ -69,6 +69,7 @@ def evaluate(
         header = "Test"
 
     for samples, targets, *_ in metric_logger.log_every(data_loader, 10, device, header):
+        # given the model went through ModelWithNormalize, outputs are already normalized
         outputs = model(samples.to(device))
         targets = targets.to(device)
         one_hot_targets = one_hot(targets, num_classes=num_classes)
@@ -139,8 +140,6 @@ def extract_features_with_dataloader(
     labels_shape = list(labels_rank.shape)
     labels_shape[0] = sample_count
     all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device)
-    if verbose:
-        logger.info(f"Storing features into tensor of shape {features.shape}")
 
     # share indexes, features and labels between processes
     index_all = all_gather_and_flatten(index).to(gather_device)
diff --git a/dinov2/inference/extract_features.py b/dinov2/inference/extract_features.py
index 3893114c2..96b9cfff7 100644
--- a/dinov2/inference/extract_features.py
+++ b/dinov2/inference/extract_features.py
@@ -139,7 +139,7 @@ def main(args):
             filenames.append(fname)
             feature_paths.append(feature_path)
             if cfg.wandb.enable and not run_distributed:
-                wandb.log({"processed": i + 1})
+                wandb.log({"processed": i + imgs.shape[0]})
 
     features_df = pd.DataFrame.from_dict(
         {