-
Notifications
You must be signed in to change notification settings - Fork 14
/
classify_imagenet.py
121 lines (97 loc) · 3.72 KB
/
classify_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import numpy as np
import torch
import clip
from pkg_resources import packaging
from imagenet_prompts.standard_image_prompts import imagenet_templates
import pdb
from collections import defaultdict
from imagenetdataset import ImagenetDataset
from PIL import Image
import PIL
import json
from tqdm import tqdm
# Directory handed to ImagenetDataset as its image root
# (presumably the ImageNet validation split -- confirm against the dataset class).
PATH_TO_IMAGENET = "../val"
# JSON file of CuPL prompts; opened and parsed by zeroshot_classifier_gpt.
PATH_TO_PROMPTS = "./imagenet_prompts/CuPL_image_prompts.json"
# Load the CLIP "ViT-L/14" model together with its matching image preprocessing transform.
model, preprocess = clip.load("ViT-L/14")
# This script only runs inference, so switch the model to evaluation mode.
model.eval()
# Dataset over PATH_TO_IMAGENET; each image is passed through CLIP's preprocessing transform.
all_images = ImagenetDataset(PATH_TO_IMAGENET, transform=preprocess)
# Batched loader used by the classification loop (512 images per batch, 8 worker processes).
loader = torch.utils.data.DataLoader(all_images, batch_size=512, num_workers=8)
def zeroshot_classifier(classnames, textnames, templates):
    """Build zero-shot classifier weights from hand-written prompt templates.

    For every class, each template is filled with the class's text name, the
    resulting prompts are embedded with the CLIP text encoder, each embedding
    is L2-normalised, and the normalised embeddings are averaged and
    re-normalised into a single vector for that class.

    Args:
        classnames: Iterable with one entry per class. Only its length and
            iteration order matter here; the entries themselves are not read.
        textnames: Object indexable by class position (``textnames[i]``) that
            yields the text inserted into each template.
        templates: Iterable of format strings that receive the class text as
            their single positional ``str.format`` argument.

    Returns:
        A tensor with the per-class vectors stacked along ``dim=1`` (one
        column per class), located on the same device as ``model``.
    """
    # Follow the model's device instead of hard-coding ``.cuda()`` so the
    # inputs always land next to the weights that consume them.
    device = next(model.parameters()).device

    zeroshot_weights = []
    with torch.no_grad():
        # ``textnames`` is indexed by position rather than zipped because it
        # may be a mapping keyed by class index, not a sequence.
        for i, _ in enumerate(tqdm(classnames)):
            text_name = textnames[i]
            texts = [template.format(text_name) for template in templates]
            tokens = clip.tokenize(texts).to(device)

            class_embeddings = model.encode_text(tokens)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)

            # Average the unit-length prompt embeddings, then re-normalise so
            # the class vector is unit-length as well.
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)

    # The per-class vectors already live on ``device``; no extra transfer needed.
    return torch.stack(zeroshot_weights, dim=1)
def zeroshot_classifier_gpt(classnames, textnames, templates, use_both):
    """Build zero-shot classifier weights from the CuPL prompt file.

    The prompts are read from ``PATH_TO_PROMPTS``, a JSON object that is
    looked up by class text name (``gpt3_prompts[textnames[i]]``). For every
    class the prompts are embedded with the CLIP text encoder, each embedding
    is L2-normalised, and the normalised embeddings are averaged and
    re-normalised into a single vector for that class.

    Args:
        classnames: Iterable with one entry per class. Only its length and
            iteration order matter here; the entries themselves are not read.
        textnames: Object indexable by class position (``textnames[i]``) that
            yields the class text name, used both as the prompt-file key and
            as the value inserted into ``templates``.
        templates: Iterable of format strings that receive the class text as
            their single positional ``str.format`` argument. Ignored when
            ``use_both`` is false.
        use_both: If true, the filled-in ``templates`` are used in addition to
            the prompts from the file; if false, only the file prompts are used.

    Returns:
        A tensor with the per-class vectors stacked along ``dim=1`` (one
        column per class), located on the same device as ``model``.

    Raises:
        KeyError: If the prompt file has no entry for one of the text names.
    """
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # default locale encoding when reading it.
    with open(PATH_TO_PROMPTS, encoding="utf-8") as f:
        gpt3_prompts = json.load(f)

    # Follow the model's device instead of hard-coding ``.cuda()`` so the
    # inputs always land next to the weights that consume them.
    device = next(model.parameters()).device

    zeroshot_weights = []
    with torch.no_grad():
        # ``textnames`` is indexed by position rather than zipped because it
        # may be a mapping keyed by class index, not a sequence.
        for i, _ in enumerate(tqdm(classnames)):
            text_name = textnames[i]
            if use_both:
                texts = [template.format(text_name) for template in templates]
            else:
                texts = []
            texts.extend(gpt3_prompts[text_name])

            # Prompts from the file can exceed CLIP's context length, hence
            # ``truncate=True``.
            tokens = clip.tokenize(texts, truncate=True).to(device)

            class_embeddings = model.encode_text(tokens)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)

            # Average the unit-length prompt embeddings, then re-normalise so
            # the class vector is unit-length as well.
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)

    # The per-class vectors already live on ``device``; no extra transfer needed.
    return torch.stack(zeroshot_weights, dim=1)
# Build the three sets of classifier weights: templates only, CuPL prompts
# only, and both combined.
print("\nCreating standard text embeddings...")
zeroshot_weights_base = zeroshot_classifier(
    all_images.idx_to_label, all_images.idx_to_text, imagenet_templates
)
print("Done.\n")

print("Creating CuPL text embeddings...")
zeroshot_weights_cupl = zeroshot_classifier_gpt(
    all_images.idx_to_label, all_images.idx_to_text, imagenet_templates, False
)
print("Done.\n")

print("Creating combined text embeddings...")
zeroshot_weights_gpt_both = zeroshot_classifier_gpt(
    all_images.idx_to_label, all_images.idx_to_text, imagenet_templates, True
)
print("Done.\n")

# Running tallies over the whole loader.
total = 0.
correct_base = 0.
correct_cupl = 0.
correct_both = 0.

# Follow the model's device instead of hard-coding ``.cuda()``.
device = next(model.parameters()).device

print("Classifying ImageNet...")
with torch.no_grad():
    # The third item yielded by the loader is not needed for scoring.
    for images, target, _ in tqdm(loader):
        images = images.to(device)
        # Flatten so a (B, 1) label tensor still compares element-wise,
        # exactly as the former per-sample comparison did.
        target = target.to(device).reshape(-1)

        # predict
        image_features = model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        logits_base = image_features @ zeroshot_weights_base
        logits_cupl = image_features @ zeroshot_weights_cupl
        logits_both = image_features @ zeroshot_weights_gpt_both

        pred_base = torch.argmax(logits_base, dim=1)
        pred_cupl = torch.argmax(logits_cupl, dim=1)
        pred_both = torch.argmax(logits_both, dim=1)

        # Count hits with one vectorised comparison per classifier rather
        # than a Python loop that synchronises with the device per sample.
        total += len(target)
        correct_base += (pred_base == target).sum().item()
        correct_cupl += (pred_cupl == target).sum().item()
        correct_both += (pred_both == target).sum().item()

print()

if not total:
    # Fail with an actionable message instead of a bare ZeroDivisionError.
    raise RuntimeError(
        f"No images were loaded from {PATH_TO_IMAGENET!r}; cannot compute accuracy."
    )

top1 = (correct_base / total) * 100
print(f"Top-1 accuracy standard: {top1:.2f}")

top1 = (correct_cupl / total) * 100
print(f"Top-1 accuracy CuPL: {top1:.2f}")

top1 = (correct_both / total) * 100
print(f"Top-1 accuracy both: {top1:.2f}")