import random
from itertools import combinations

import numpy as np
import torch
import torch.nn.functional as F

# Fix the RNG so triplet sampling is reproducible across runs.
random.seed(0)

def get_random_triplets(embeddings) -> torch.Tensor:
    '''
    For each image in the data (the Anchor), randomly sample a Positive image
    from the same class. Then sample one image at random from each of the
    other classes as the Negative. Hence, every Anchor gets one randomly
    selected Positive from its own class and one randomly selected Negative
    from each of the n-1 other classes, where n is the total number of
    classes, i.e. n-1 triplets per Anchor image.
    So with 3 classes of 10 images each you would have 3*10*(3-1) = 60 triplets.
    '''
    triplets = []
    for i, anchor_class in enumerate(embeddings):
        # Every class except the anchor's own class supplies the negatives.
        other_classes = embeddings[:i] + embeddings[i + 1:]
        for anchor in anchor_class:
            # Note: random.choice may pick the anchor itself as the positive.
            positive = random.choice(anchor_class)
            for negatives in other_classes:
                negative = random.choice(negatives)
                triplets.append(torch.stack([anchor, positive, negative], dim=0))
    return torch.stack(triplets, dim=0)
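
# A minimal usage sketch (assuming each class's embeddings are a tensor of
# shape [images_per_class, embed_dim]); the sizes below are made up for
# illustration:
#
#   embeds = [torch.randn(10, 128) for _ in range(3)]  # 3 classes, 10 images each
#   trips = get_random_triplets(embeds)
#   print(trips.shape)  # torch.Size([60, 3, 128]) -- 3*10*(3-1) triplets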

def triplet_loss(anchor, positive, negative, margin=1):
    # ref: https://github.com/adambielski/siamese-triplet/blob/master/losses.py#L24
    # Squared Euclidean distances; the commented-out .pow(.5) would give the
    # true L2 distance instead.
    pos_dist = (anchor - positive).pow(2).sum(-1)  # .pow(.5)
    neg_dist = (anchor - negative).pow(2).sum(-1)  # .pow(.5)
    # Hinge on the margin: the loss is zero once the negative is at least
    # `margin` farther from the anchor than the positive.
    loss = F.relu(pos_dist - neg_dist + margin)
    return loss.mean()
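
# A quick sanity check of the loss on toy tensors (values chosen arbitrarily
# for illustration):
#
#   a = torch.zeros(4, 128)
#   p = torch.zeros(4, 128)   # pos_dist = 0
#   n = torch.ones(4, 128)    # neg_dist = 128
#   triplet_loss(a, p, n)     # relu(0 - 128 + 1) = 0   -> easy triplets
#   triplet_loss(a, n, p)     # relu(128 - 0 + 1) = 129 -> hard triplets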

def count_parameters(model):
    # Total number of trainable (requires_grad) parameters in the model.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
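
# Example (torch.nn.Linear is just a stand-in model for illustration):
#
#   import torch.nn as nn
#   count_parameters(nn.Linear(128, 10))  # 128*10 weights + 10 biases = 1290
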
#################################################
#### CODE BEYOND THIS LINE IS NO LONGER USED ####
#################################################

def __get_random_triplets(no_classes: int, images, target, no_triplets: int):
    # Ref : https://github.com/tamerthamoqa/facenet-pytorch-vggface2/blob/master/datasets/TripletLossDataset.py#L76-L136
    randomstate = np.random.RandomState(seed=None)
    # no_classes = len(lfw_people.target_names)
    triplets = []
    class_pairs = []
    # progress_bar = tqdm(range(no_triplets), desc='fetching triplets')
    for _ in range(no_triplets):
        # Draw two distinct classes: one for the anchor/positive pair,
        # one for the negative.
        pos_class = randomstate.choice(no_classes)
        neg_class = randomstate.choice(no_classes)
        while pos_class == neg_class:
            neg_class = randomstate.choice(no_classes)
        # pos_name = lfw_people.target_names[pos_class]
        # neg_name = lfw_people.target_names[neg_class]
        pos_imgs = images[target == pos_class]
        neg_imgs = images[target == neg_class]
        # Pick two distinct images from the positive class (assumes the class
        # has at least two images; a single-image class would loop forever).
        if pos_imgs.shape[0] == 2:
            ianc, ipos = 0, 1
        else:
            ianc = randomstate.randint(0, pos_imgs.shape[0])
            ipos = randomstate.randint(0, pos_imgs.shape[0])
            while ianc == ipos:
                ipos = randomstate.randint(0, pos_imgs.shape[0])
        ineg = randomstate.randint(0, neg_imgs.shape[0])
        triplets.append(
            torch.stack([
                torch.from_numpy(pos_imgs[ianc] / 255),  # scale pixels to [0, 1]
                torch.from_numpy(pos_imgs[ipos] / 255),
                torch.from_numpy(neg_imgs[ineg] / 255)
            ]))
        class_pairs.append((pos_class, neg_class))
    return torch.stack(triplets), class_pairs

def __get_all_tensor_triplets(embeddings: list, targets: list) -> torch.Tensor:
    '''
    Reasons for no longer considering all possible triplets:
    1. With every possible triplet included, the model is more prone to
       overfitting.
    2. On a larger dataset this is bad practice: enumerating all possible
       triplets isn't feasible, and the model will memorize the data,
       again resulting in overfitting.

    Parameters
    ----------
    embeddings : list of torch.Tensor, each of shape torch.Size([?, 128])
    targets : list of ints

    Returns
    -------
    triplets : torch.Tensor of shape torch.Size([no(triplets), 3, 128])
    '''
    # eg : no(targets) = 3 classes
    # eg : no(embeds) = 10 embeddings per class
    assert len(embeddings) == len(targets), "Embeddings and Targets must have the same length"
    triplets = []
    for i, anchor_class in enumerate(embeddings):
        positive_pairs = list(combinations(anchor_class, 2))  # all distinct pairs within the anchor class
        # no(pos_pairs) = C(no(embeds), 2)  # eg : 10*9/2 = 45
        others = embeddings[:i] + embeddings[i + 1:]  # every class except the anchor's
        for negative in torch.cat(others, dim=0):  # loop runs (no(targets)-1) * no(embeds) times  # eg : (3-1)*10 = 20
            triple = [torch.stack([pos[0], pos[1], negative], dim=0) for pos in positive_pairs]  # no(triple) = no(pos_pairs)  # eg : 45
            triplets.extend(triple)  # triplets added per anchor class = no(pos_pairs) * (no(targets)-1) * no(embeds)  # eg : 45*2*10 = 900
    return torch.stack(triplets, dim=0)  # no(triplets) = 900 * no(targets)  # eg : 900*3 = 2700
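
# A worked shape check of the triplet count above (toy tensors, made up for
# illustration):
#
#   embeds = [torch.randn(10, 128) for _ in range(3)]
#   trips = __get_all_tensor_triplets(embeds, [0, 1, 2])
#   print(trips.shape)  # torch.Size([2700, 3, 128]) = 3 classes * 20 negatives * 45 pairs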