Skip to content

Commit

Permalink
added code for feature extraction: new dataset based on ImageFolder, …
Browse files Browse the repository at this point in the history
…new sampler type, new PatchEmbedder model, new config file
  • Loading branch information
clemsgrs committed Mar 7, 2024
1 parent a265bb7 commit 4b80575
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 0 deletions.
27 changes: 27 additions & 0 deletions dinov2/configs/inference/vits14.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Inference / feature-extraction config for a ViT-S/14 DINOv2 model.
# ${...} entries are presumably OmegaConf interpolations resolved at load time — TODO confirm.
dino:
  head_bottleneck_dim: 384
train:
  centering: sinkhorn_knopp
inference:
  # Root folder of the image dataset (ImageFolder layout: one subdir per class).
  data_dir: /root/data
  batch_size: 64
  num_workers: 8
student:
  arch: vit_small
  patch_size: 14
  num_register_tokens: 0
  # Teacher checkpoint produced by a previous training run.
  pretrained_weights: '/data/pathology/projects/ais-cap/clement/code/dinov2/output/769naczt/eval/training_649999/teacher_checkpoint.pth'
  drop_path_rate: 0.4
  ffn_layer: swiglufused
  block_chunks: 4
crops:
  local_crops_size: 98
wandb:
  enable: false
  project: 'vision'
  username: 'vlfm'
  exp_name: 'feature_extraction'
  tags: ['${wandb.exp_name}', 'patch', '${student.arch}']
  dir: '/home/user'
  # Left empty on purpose: filled in at runtime when resuming a run.
  group:
  resume_id:
1 change: 1 addition & 0 deletions dinov2/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .pathology import PathologyDataset
from .knn import KNNDataset
from .foundation import PathologyFoundationDataset
from .image_folder import ImageFolderWithNameDataset
30 changes: 30 additions & 0 deletions dinov2/data/datasets/image_folder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from pathlib import Path
from torchvision import datasets
from typing import Callable, Optional


class ImageFolderWithNameDataset(datasets.ImageFolder):
    """``ImageFolder`` variant that yields the image's filename stem instead of its class index.

    Useful for feature extraction, where each embedding must be saved under
    the name of the image it came from rather than its class label.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ):
        super().__init__(
            root,
            transform,
        )

    def __getitem__(self, idx: int):
        """
        Args:
            idx (int): index
        Returns:
            tuple: (sample, fname) where sample is the (transformed) image and
            fname is the image file's stem (filename without extension).
        """
        path, _ = self.samples[idx]
        fname = Path(path).stem
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, fname
9 changes: 9 additions & 0 deletions dinov2/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class SamplerType(Enum):
INFINITE = 2
SHARDED_INFINITE = 3
SHARDED_INFINITE_NEW = 4
RANDOM = 5


def _make_bool_str(b: bool) -> str:
Expand Down Expand Up @@ -177,6 +178,14 @@ def _make_sampler(
seed=seed,
drop_last=False,
)
elif type == SamplerType.RANDOM:
if verbose:
logger.info("sampler: random")
if size > 0:
raise ValueError("sampler size > 0 is invalid")
if advance > 0:
raise ValueError("sampler advance > 0 is invalid")
return torch.utils.data.RandomSampler(dataset)

if verbose:
logger.info("sampler: none")
Expand Down
179 changes: 179 additions & 0 deletions dinov2/inference/extract_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import os
import tqdm
import torch
import wandb
import argparse
import datetime
import pandas as pd
import multiprocessing as mp

from pathlib import Path

import dinov2.distributed as distributed

from dinov2.models import PatchEmbedder
from dinov2.utils.config import setup, write_config
from dinov2.utils.utils import initialize_wandb
from dinov2.data import SamplerType, make_data_loader
from dinov2.data.datasets import ImageFolderWithNameDataset
from dinov2.data.transforms import make_classification_eval_transform


def get_args_parser(add_help: bool = True):
    """Build the CLI parser for the feature-extraction script.

    Args:
        add_help: whether to add the default ``-h/--help`` option.
    Returns:
        argparse.ArgumentParser with ``--config-file``, ``--output-dir`` and
        trailing ``opts`` overrides.
    """
    # Description fixed: this is the feature-extraction entry point, not training.
    parser = argparse.ArgumentParser("DINOv2 feature extraction", add_help=add_help)
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        default="",
        type=str,
        help="Output directory to save logs and checkpoints",
    )

    return parser


def main(args):
    """Extract per-image features with a frozen ViT and write them to disk.

    Each image's embedding is saved as ``<output_dir>/<run_id>/features/<stem>.pt``
    and an index CSV (filename -> feature path) is written alongside. In the
    multi-GPU case each rank writes its own CSV, then rank 0 merges them.
    """
    cfg = setup(args)

    # One process per GPU when more than one GPU is visible; gpu_id == -1
    # means single-process / default CUDA device.
    run_distributed = torch.cuda.device_count() > 1
    if run_distributed:
        gpu_id = int(os.environ["LOCAL_RANK"])
    else:
        gpu_id = -1

    if distributed.is_main_process():
        print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")
        # run_id doubles as the output subdirectory name; replaced by the
        # wandb run id when wandb logging is enabled.
        run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
        # set up wandb
        if cfg.wandb.enable:
            key = os.environ.get("WANDB_API_KEY")
            wandb_run = initialize_wandb(cfg, key=key)
            run_id = wandb_run.id
    else:
        run_id = ""

    if run_distributed:
        # Share rank 0's run_id with all other ranks so every process writes
        # into the same output directory.
        obj = [run_id]
        torch.distributed.broadcast_object_list(obj, 0, device=torch.device(f"cuda:{gpu_id}"))
        run_id = obj[0]

    output_dir = Path(cfg.train.output_dir, run_id)
    if distributed.is_main_process():
        output_dir.mkdir(exist_ok=True, parents=True)
    cfg.train.output_dir = str(output_dir)
    # NOTE(review): non-main ranks reach this mkdir without parents=True and
    # without a barrier after rank 0 creates output_dir — looks like a startup
    # race if rank 0 is slow; confirm whether a barrier is needed here.
    features_dir = Path(output_dir, "features")
    features_dir.mkdir(exist_ok=True)

    if distributed.is_main_process():
        write_config(cfg, cfg.train.output_dir)

    model = PatchEmbedder(
        cfg,
        verbose=distributed.is_main_process(),
    )

    transform = make_classification_eval_transform()
    dataset = ImageFolderWithNameDataset(cfg.inference.data_dir, transform)

    # DISTRIBUTED sampler shards the dataset across ranks; RANDOM is the
    # single-process fallback.
    if run_distributed:
        sampler_type = SamplerType.DISTRIBUTED
    else:
        sampler_type = SamplerType.RANDOM

    # Cap workers by available CPUs, and by the SLURM allocation when present.
    num_workers = min(mp.cpu_count(), cfg.inference.num_workers)
    if "SLURM_JOB_CPUS_PER_NODE" in os.environ:
        num_workers = min(num_workers, int(os.environ["SLURM_JOB_CPUS_PER_NODE"]))

    data_loader = make_data_loader(
        dataset=dataset,
        batch_size=cfg.inference.batch_size,
        num_workers=num_workers,
        sampler_type=sampler_type,
        drop_last=False,
        shuffle=False,
    )

    if gpu_id == -1:
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{gpu_id}")

    model = model.to(device, non_blocking=True)
    model.eval()

    if distributed.is_main_process():
        print()

    # Parallel lists: image filename stem and the path its feature was saved to.
    filenames, feature_paths = [], []

    with tqdm.tqdm(
        data_loader,
        desc="Feature Extraction",
        unit=" img",
        ncols=80,
        position=0,
        leave=True,
        # Only rank 0 (or the single process) shows a progress bar.
        disable=not (gpu_id in [-1, 0]),
    ) as t1:
        with torch.no_grad():
            for i, batch in enumerate(t1):
                imgs, fnames = batch
                imgs = imgs.to(device, non_blocking=True)
                features = model(imgs)
                # Save one .pt file per image, keyed by its filename stem.
                for k, f in enumerate(features):
                    fname = fnames[k]
                    feature_path = Path(features_dir, f"{fname}.pt")
                    torch.save(f, feature_path)
                    filenames.append(fname)
                    feature_paths.append(feature_path)
                if cfg.wandb.enable and not run_distributed:
                    # NOTE(review): i is the batch index, so "processed" counts
                    # batches, not images — confirm intended unit.
                    wandb.log({"processed": i + 1})

    features_df = pd.DataFrame.from_dict(
        {
            "filename": filenames,
            "feature_path": feature_paths,
        }
    )

    # Each rank writes its own CSV; rank 0 merges them below.
    if run_distributed:
        features_csv_path = Path(output_dir, f"features_{gpu_id}.csv")
    else:
        features_csv_path = Path(output_dir, "features.csv")
    features_df.to_csv(features_csv_path, index=False)

    if run_distributed:
        # Wait until every rank has flushed its CSV before merging.
        torch.distributed.barrier()
        if distributed.is_main_process():
            dfs = []
            for gpu_id in range(torch.cuda.device_count()):
                fp = Path(output_dir, f"features_{gpu_id}.csv")
                df = pd.read_csv(fp)
                dfs.append(df)
                # Per-rank CSVs are temporary; remove after merging.
                os.remove(fp)
            features_df = pd.concat(dfs, ignore_index=True)
            features_df = features_df.drop_duplicates()
            features_df.to_csv(Path(output_dir, "features.csv"), index=False)

    if cfg.wandb.enable and distributed.is_main_process() and run_distributed:
        wandb.log({"processed": len(features_df)})


if __name__ == "__main__":
    import warnings

    # torch.save/load emit a noisy TypedStorage deprecation warning; hide it.
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message="TypedStorage is deprecated",
    )

    parsed_args = get_args_parser(add_help=True).parse_args()
    main(parsed_args)
42 changes: 42 additions & 0 deletions dinov2/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
# found in the LICENSE file in the root directory of this source tree.

import logging
import torch
import torch.nn as nn

from pathlib import Path

from . import vision_transformer as vits

Expand Down Expand Up @@ -56,3 +60,41 @@ def update_state_dict(model_dict, state_dict):
success += 1
msg = f"{success} weight(s) loaded succesfully ; {failure} weight(s) not loaded because of mismatching shapes"
return updated_state_dict, msg


class PatchEmbedder(nn.Module):
    """Frozen ViT backbone used as an inference-time patch embedder.

    Builds the ViT from ``cfg``, optionally loads pretrained teacher weights
    (checkpoint key ``"teacher"``), and freezes every parameter so the module
    only produces embeddings.

    Args:
        cfg: config with ``student.pretrained_weights`` and model settings.
        verbose: print progress messages (typically only on the main process).
    """

    def __init__(
        self,
        cfg,
        verbose: bool = True,
    ):
        super().__init__()
        checkpoint_key = "teacher"

        self.vit, _, _ = build_model_from_cfg(cfg)

        if Path(cfg.student.pretrained_weights).is_file():
            if verbose:
                print(f"Pretrained weights: loading from {cfg.student.pretrained_weights}")
            # map_location="cpu" so loading works regardless of which device the
            # checkpoint was saved from; the caller moves the module to its
            # target device afterwards.
            chkpt = torch.load(cfg.student.pretrained_weights, map_location="cpu")
            sd = chkpt[checkpoint_key]
            # Keep only weights whose shapes match the freshly built ViT.
            sd, msg = update_state_dict(self.vit.state_dict(), sd)
            self.vit.load_state_dict(sd, strict=False)
            if verbose:
                print(f"Pretrained weights loaded: {msg}")

        elif verbose:
            print(f"{cfg.student.pretrained_weights} doesn't exist; please provide a path to an existing file")

        # Freeze the backbone: this module is extraction-only.
        if verbose:
            print("Freezing Vision Transformer")
        for param in self.vit.parameters():
            param.requires_grad = False
        if verbose:
            print("Done")

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: tensor of shape [B, 3, img_size, img_size].
        Returns:
            CPU tensor of shape [B, embed_dim] (e.g. 384 for vit_small).
        """
        # TODO: add prepare_img_tensor method
        feature = self.vit(x).detach().cpu()
        return feature

0 comments on commit 4b80575

Please sign in to comment.