-
Notifications
You must be signed in to change notification settings - Fork 20
/
utils.py
137 lines (107 loc) · 5.99 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import numpy as np
import torch
import torch.nn.functional as F
import kornia as K
from PIL import Image
from CLIP.tokenizer import tokenize
def encode_text_with_prompt_ensemble(model, obj, device):
    """Build normal/abnormal text embeddings for `obj` via prompt ensembling.

    Each state (normal, abnormal) is described by several state phrases,
    each expanded through every template sentence; the CLIP text embeddings
    of all resulting sentences are L2-normalized, averaged, and re-normalized
    into a single feature per state.

    Args:
        model: CLIP-style model exposing `encode_text(tokens)`.
        obj: object/class name substituted into the prompts.
        device: torch device for tokenized prompts and the result.

    Returns:
        Tensor of shape (embed_dim, 2): column 0 = normal, column 1 = abnormal.
    """
    prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
    prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
    prompt_state = [prompt_normal, prompt_abnormal]
    prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
    text_features = []
    for state_prompts in prompt_state:
        # Fill the object name into each state phrase, then expand every
        # phrase through every template sentence.
        phrases = [p.format(obj) for p in state_prompts]
        sentences = [template.format(phrase)
                     for phrase in phrases
                     for template in prompt_templates]
        tokens = tokenize(sentences).to(device)
        embeddings = model.encode_text(tokens)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        # Average the per-sentence embeddings and renormalize the centroid.
        centroid = embeddings.mean(dim=0)
        centroid = centroid / centroid.norm()
        text_features.append(centroid)
    return torch.stack(text_features, dim=1).to(device)
def cos_sim(a, b, eps=1e-8):
    """Pairwise cosine similarity between the rows of `a` and the rows of `b`.

    Args:
        a: tensor of shape (n, d).
        b: tensor of shape (m, d).
        eps: lower clamp on row norms to avoid division by zero.

    Returns:
        (n, m) similarity matrix.
    """
    # F.normalize divides each row by max(norm, eps) — same guard as the
    # manual clamp-and-divide formulation.
    a_unit = F.normalize(a, p=2, dim=1, eps=eps)
    b_unit = F.normalize(b, p=2, dim=1, eps=eps)
    return a_unit @ b_unit.t()
def get_rot_mat(theta):
    """Return the 2x3 affine matrix for a rotation by `theta` radians.

    The third column (translation) is zero; the matrix is suitable for
    `F.affine_grid`.
    """
    angle = torch.tensor(theta)
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    return torch.tensor([[cos_a, -sin_a, 0],
                         [sin_a, cos_a, 0]])
def get_translation_mat(a, b):
    """Return the 2x3 affine matrix translating by (a, b) in grid coordinates."""
    row_x = [1, 0, a]
    row_y = [0, 1, b]
    return torch.tensor([row_x, row_y])
def rot_img(x, theta):
    """Rotate a batch of images by `theta` radians around the image center.

    Args:
        x: image batch of shape (B, C, H, W).
        theta: rotation angle in radians (python float or scalar tensor).

    Returns:
        Rotated batch with the same shape as `x`; out-of-frame samples are
        filled with reflection padding.
    """
    # Build the affine matrix on x's device/dtype. The previous code forced
    # torch.FloatTensor (CPU float32), which broke GPU or half-precision
    # inputs passed through F.grid_sample.
    theta = torch.as_tensor(theta, dtype=x.dtype, device=x.device)
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    zero = torch.zeros_like(cos_t)
    rot_mat = torch.stack([
        torch.stack([cos_t, -sin_t, zero]),
        torch.stack([sin_t, cos_t, zero]),
    ])[None].repeat(x.shape[0], 1, 1)
    # align_corners=False is the framework default; stated explicitly for
    # clarity and to silence version-dependent warnings.
    grid = F.affine_grid(rot_mat, x.size(), align_corners=False)
    return F.grid_sample(x, grid, padding_mode="reflection", align_corners=False)
def translation_img(x, a, b):
    """Shift the sampling grid of a batch of images by (a, b).

    Args:
        x: image batch of shape (B, C, H, W).
        a: horizontal offset in normalized coordinates (full width == 2.0);
           positive values sample further right, shifting content left.
        b: vertical offset in normalized coordinates.

    Returns:
        Translated batch with the same shape as `x`; out-of-frame samples
        are filled with reflection padding.
    """
    # new_tensor keeps x's device/dtype. The previous code forced
    # torch.FloatTensor (CPU float32), breaking GPU/half-precision inputs.
    trans_mat = x.new_tensor([[1.0, 0.0, a],
                              [0.0, 1.0, b]])[None].repeat(x.shape[0], 1, 1)
    # align_corners=False is the framework default, made explicit here.
    grid = F.affine_grid(trans_mat, x.size(), align_corners=False)
    return F.grid_sample(x, grid, padding_mode="reflection", align_corners=False)
def hflip_img(x):
    """Mirror a batch of images horizontally (flip along the width axis)."""
    return K.geometry.transform.hflip(x)
def vflip_img(x):
    """Mirror a batch of images vertically (flip along the height axis)."""
    return K.geometry.transform.vflip(x)
def rot90_img(x, k):
    """Rotate a batch of images by k * 90 degrees (k in {0, 1, 2, 3}).

    Args:
        x: image batch of shape (B, C, H, W).
        k: quarter-turn count; index 4 (360 degrees) is also accepted.

    Returns:
        Rotated batch; out-of-frame samples use reflection padding.
    """
    # The last entry was previously the bare int 360, which made
    # torch.tensor produce an int64 angle that kornia's rotate rejects;
    # all entries are now floats.
    degrees = torch.tensor((0., 90., 180., 270., 360.)[k])
    return K.geometry.transform.rotate(x, angle=degrees, padding_mode='reflection')
def augment(fewshot_img, fewshot_mask=None):
    """Expand a few-shot image batch (and optional masks) with fixed augmentations.

    Every input image is kept and additionally rotated by four small angles,
    translated by four diagonal offsets, and flipped horizontally and
    vertically. Concatenation order along the batch dimension is:
    original, rotations (-pi/8, -pi/16, pi/16, pi/8), translations
    ((0.1,0.1), (-0.1,0.1), (-0.1,-0.1), (0.1,-0.1)), hflip, vflip.

    Args:
        fewshot_img: image batch of shape (B, C, H, W).
        fewshot_mask: optional mask batch augmented with the exact same
            transforms; if None, an all-zero (B', 1, H, W) mask is returned.

    Returns:
        Tuple (augmented images, augmented masks), batch dims matching.
    """
    rot_angles = [-np.pi / 8, -np.pi / 16, np.pi / 16, np.pi / 8]
    shifts = [(0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]

    def _augment_one(t):
        # Apply the full transform set to one tensor, original first, so the
        # image and mask pipelines stay in lockstep (previously this logic
        # was duplicated across the two branches).
        parts = [t]
        parts += [rot_img(t, angle) for angle in rot_angles]
        parts += [translation_img(t, a, b) for a, b in shifts]
        parts.append(hflip_img(t))
        parts.append(vflip_img(t))
        return torch.cat(parts, dim=0)

    augment_fewshot_img = _augment_one(fewshot_img)
    if fewshot_mask is not None:
        augment_fewshot_mask = _augment_one(fewshot_mask)
    else:
        # No ground-truth masks: return all-zero masks of the matching size.
        B, _, H, W = augment_fewshot_img.shape
        augment_fewshot_mask = torch.zeros([B, 1, H, W])
    return augment_fewshot_img, augment_fewshot_mask