objective_grad.py
import torch
import util
def compute_grad_A(C, Q, R, Lambda, gamma, \
                   semiRelaxedLeft, semiRelaxedRight, device, \
                   Wasserstein=True, FGW=False, A=None, B=None, \
                   alpha=0.0, unbalanced=False, \
                   dtype=torch.float64, full_grad=True):
    '''
    Compute the Wasserstein, Gromov-Wasserstein, or fused Gromov-Wasserstein gradients with respect to Q and R.
    The gradients depend on the marginal relaxation of the specific OT problem being solved, due to proportionality simplifications.
    ------Parameters------
    C: torch.tensor (N1 x N2)
        Matrix of pairwise feature distances between space X and space Y (inter-space).
    Q: torch.tensor (N1 x r)
        The left sub-coupling matrix.
    R: torch.tensor (N2 x r)
        The right sub-coupling matrix.
    Lambda: torch.tensor (r x r)
        The inner transition matrix.
    gamma: float
        The mirror-descent step-size.
    semiRelaxedLeft: bool
        True if relaxing the left marginal.
    semiRelaxedRight: bool
        True if relaxing the right marginal.
    device: str
        The device (e.g. 'cpu' or 'cuda').
    Wasserstein: bool
        True if using the Wasserstein loss <C, P>_F as the objective cost;
        otherwise runs GW if FGW is False and FGW if FGW is True.
    FGW: bool
        True if running the fused Gromov-Wasserstein problem, and otherwise False.
    A: torch.tensor (N1 x N1)
        Pairwise distance matrix in metric space X.
    B: torch.tensor (N2 x N2)
        Pairwise distance matrix in metric space Y.
    alpha: float
        Balance parameter between the Wasserstein term and
        the Gromov-Wasserstein term of the objective.
    unbalanced: bool
        True if running the unbalanced problem;
        if semiRelaxedLeft/Right and unbalanced are all False (default), the balanced problem is run.
    dtype: torch.dtype
        Tensor dtype used for the helper tensors.
    full_grad: bool
        If True, include the rank-one correction terms in the gradients.
    ------Returns------
    gradQ, gradR: the gradients with respect to Q and R.
    gamma_k: the step-size rescaled by the largest gradient magnitude.
    '''
    r = Lambda.shape[0]
    one_r = torch.ones((r), device=device, dtype=dtype)
    One_rr = torch.outer(one_r, one_r).to(device)
    if Wasserstein:
        gradQ, gradR = Wasserstein_Grad(C, Q, R, Lambda, device, \
                                        dtype=dtype, full_grad=full_grad)
    elif A is not None and B is not None:
        if not semiRelaxedLeft and not semiRelaxedRight and not unbalanced:
            # Balanced gradient (Q1_r = a AND R1_r = b)
            gradQ = - 4 * (A@Q)@Lambda@(R.T@B@R)@Lambda.T
            gradR = - 4 * (B@R@Lambda.T)@(Q.T@A@Q)@Lambda
        elif semiRelaxedRight:
            # Semi-relaxed right marginal gradient (Q1_r = a)
            gradQ = - 4 * (A@Q)@Lambda@(R.T@B@R)@Lambda.T
            gradR = 2*B**2 @ R @ One_rr - 4*(B@R@Lambda.T)@(Q.T@A@Q)@Lambda
        elif semiRelaxedLeft:
            # Semi-relaxed left marginal gradient (R1_r = b)
            gradQ = 2*A**2 @ Q @ One_rr - 4 * (A@Q)@Lambda@(R.T@B@R)@Lambda.T
            gradR = - 4 * (B@R@Lambda.T)@(Q.T@A@Q)@Lambda
        elif unbalanced:
            # Fully unbalanced with no marginal constraints
            gradQ = 2*A**2 @ Q @ One_rr - 4 * (A@Q)@Lambda@(R.T@B@R)@Lambda.T
            gradR = 2*B**2 @ R @ One_rr - 4 * (B@R@Lambda.T)@(Q.T@A@Q)@Lambda
        if full_grad:
            N1, N2 = Q.shape[0], R.shape[0]
            one_N1, one_N2 = torch.ones((N1), device=device, dtype=dtype), torch.ones((N2), device=device, dtype=dtype)
            gQ, gR = Q.T @ one_N1, R.T @ one_N2
            F = (Q@Lambda@R.T)
            MR = Lambda.T @ Q.T @ A @ F @ B @ R @ torch.diag(1/gR)
            MQ = Lambda @ R.T @ B @ F.T @ A @ Q @ torch.diag(1/gQ)
            gradQ += 4*torch.outer(one_N1, torch.diag(MQ))
            gradR += 4*torch.outer(one_N2, torch.diag(MR))
        # Readjust cost for the FGW problem
        if FGW:
            gradQW, gradRW = Wasserstein_Grad(C, Q, R, Lambda, device, \
                                              dtype=dtype, full_grad=full_grad)
            gradQ = (1-alpha)*gradQW + alpha*gradQ
            gradR = (1-alpha)*gradRW + alpha*gradR
    else:
        raise ValueError("---Input either Wasserstein=True or provide distance matrices A and B for GW problem---")
    normalizer = torch.max(torch.tensor([torch.max(torch.abs(gradQ)), torch.max(torch.abs(gradR))]))
    gamma_k = gamma / normalizer
    return gradQ, gradR, gamma_k
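# ---------------------------------------------------------------------------
# Illustrative note (an assumption, not part of the original source): gamma_k
# rescales the user-supplied step-size by the largest gradient entry, so a
# mirror-descent caller might apply a multiplicative update of the form
#     Q_new = Q * torch.exp(-gamma_k * gradQ)
# followed by a projection/rescaling onto the feasible coupling set, whose
# exact form depends on the chosen marginal constraints.
# ---------------------------------------------------------------------------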
def compute_grad_B(C, Q, R, Lambda, gQ, gR, gamma, device, Wasserstein=True, \
                   FGW=False, A=None, B=None, alpha=0.0, \
                   dtype=torch.float64):
    '''
    Compute the Wasserstein, Gromov-Wasserstein, or fused Gromov-Wasserstein gradients with respect to the inner matrix T.
    ------Parameters------
    C: torch.tensor (N1 x N2)
        Matrix of pairwise feature distances between space X and space Y (inter-space).
    Q: torch.tensor (N1 x r)
        The left sub-coupling matrix.
    R: torch.tensor (N2 x r)
        The right sub-coupling matrix.
    Lambda: torch.tensor (r x r)
        The inner transition matrix.
    gQ: torch.tensor (r)
        The inner marginal corresponding to the matrix Q.
    gR: torch.tensor (r)
        The inner marginal corresponding to the matrix R.
    gamma: float
        The mirror-descent step-size.
    device: str
        The device (e.g. 'cpu' or 'cuda').
    Wasserstein: bool
        True if using the Wasserstein loss <C, P>_F as the objective cost;
        otherwise runs GW if FGW is False and FGW if FGW is True.
    FGW: bool
        True if running the fused Gromov-Wasserstein problem, and otherwise False.
    A: torch.tensor (N1 x N1)
        Pairwise distance matrix in metric space X.
    B: torch.tensor (N2 x N2)
        Pairwise distance matrix in metric space Y.
    alpha: float
        Balance parameter between the Wasserstein term and
        the Gromov-Wasserstein term of the objective.
    ------Returns------
    gradT: the gradient with respect to T.
    gamma_T: the step-size rescaled by the largest gradient magnitude.
    '''
    if Wasserstein:
        gradLambda = Q.T @ C @ R
    else:
        gradLambda = -4 * Q.T @ A @ Q @ Lambda @ R.T @ B @ R
    if FGW:
        gradLambda = (1-alpha)*(Q.T @ C @ R) + alpha*gradLambda
    gradT = torch.diag(1/gQ) @ gradLambda @ torch.diag(1/gR) # (mass-reweighted form)
    gamma_T = gamma / torch.max(torch.abs(gradT))
    return gradT, gamma_T
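# ---------------------------------------------------------------------------
# Reading of the reweighting above (comment added for clarity): with
# D_gQ = diag(1/gQ) and D_gR = diag(1/gR), the returned gradient is
#     gradT = D_gQ @ (dObjective/dLambda) @ D_gR,
# i.e. the Lambda-gradient rescaled by the inverse inner marginals, and the
# step-size gamma_T = gamma / max|gradT| keeps the largest update entry at gamma.
# ---------------------------------------------------------------------------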
def Wasserstein_Grad(C, Q, R, Lambda, device, \
                     dtype=torch.float64, full_grad=True):
    '''
    Gradients of the Wasserstein (linear) term with respect to Q and R.
    If full_grad is True, a rank-one correction accounting for the inner marginals is subtracted.
    '''
    # linear term
    gradQ = (C @ R) @ Lambda.T
    if full_grad:
        # rank-one perturbation
        N1 = Q.shape[0]
        one_N1 = torch.ones((N1), device=device, dtype=dtype)
        gQ = Q.T @ one_N1
        w1 = torch.diag( (gradQ.T @ Q) @ torch.diag(1/gQ) )
        gradQ -= torch.outer(one_N1, w1)
    # linear term
    gradR = (C.T @ Q) @ Lambda
    if full_grad:
        # rank-one perturbation
        N2 = R.shape[0]
        one_N2 = torch.ones((N2), device=device, dtype=dtype)
        gR = R.T @ one_N2
        w2 = torch.diag( torch.diag(1/gR) @ (R.T @ gradR) )
        gradR -= torch.outer(one_N2, w2)
    return gradQ, gradR
'''
--------------
Code for gradients assuming low-rank factorizations of the distance matrices C, A, B
--------------
'''
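# ---------------------------------------------------------------------------
# Note on the factor convention (inferred from how the factors are used below,
# not stated explicitly in the original source): each distance matrix is passed
# as a pair of factors whose product reconstructs it, e.g.
#     C ~= C_factors[0] @ C_factors[1]   with C_factors[0]: (N1 x d_C), C_factors[1]: (d_C x N2)
#     A ~= A_factors[0] @ A_factors[1]   with A_factors[0]: (N1 x d_A), A_factors[1]: (d_A x N1)
#     B ~= B_factors[0] @ B_factors[1]   with B_factors[0]: (N2 x d_B), B_factors[1]: (d_B x N2)
# so that every product below only touches (N x d)- or (d x N)-shaped operands.
# ---------------------------------------------------------------------------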
def compute_grad_A_LR(C_factors, A_factors, B_factors, Q, R, Lambda, gamma, device, \
                      alpha=0.0, dtype=torch.float64, full_grad=False):
    '''
    Compute the fused Gromov-Wasserstein gradients with respect to Q and R when the
    distance matrices C, A, B are given through low-rank factorizations
    (C_factors, A_factors, B_factors). Only the balanced case is handled here.
    Returns gradQ, gradR, and the rescaled step-size gamma_k.
    '''
    r = Lambda.shape[0]
    one_r = torch.ones((r), device=device, dtype=dtype)
    One_rr = torch.outer(one_r, one_r).to(device)
    N1, N2 = C_factors[0].size(0), C_factors[1].size(1)
    A1, A2 = A_factors[0], A_factors[1]
    B1, B2 = B_factors[0], B_factors[1]
    # A**2's low-rank factorization (computed here, not used in the balanced gradients below)
    A1_tild, A2_tild = util.hadamard_square_lr(A1, A2.T, device=device)
    # GW gradients for the balanced marginal case
    gradQ = - 4 * (A1 @ (A2 @ (Q @ Lambda@( (R.T@ B1) @ (B2 @R) )@Lambda.T)) )
    gradR = - 4 * (B1 @ (B2 @ (R @ (Lambda.T@( (Q.T @ A1) @ ( A2 @ Q ))@Lambda)) ) )
    one_N1, one_N2 = torch.ones((N1), device=device, dtype=dtype), torch.ones((N2), device=device, dtype=dtype)
    if full_grad:
        # Rank-1 GW perturbation
        N1, N2 = Q.shape[0], R.shape[0]
        gQ, gR = Q.T @ one_N1, R.T @ one_N2
        MR = Lambda.T @ ( (Q.T @ A1) @ (A2 @ Q) ) @ Lambda @ ((R.T @ B1) @ (B2 @ R)) @ torch.diag(1/gR)
        MQ = Lambda @ ( (R.T @ B1) @ (B2 @ R) ) @ Lambda.T @ ((Q.T @ A1) @ (A2 @ Q) ) @ torch.diag(1/gQ)
        gradQ += 4*torch.outer(one_N1, torch.diag(MQ))
        gradR += 4*torch.outer(one_N2, torch.diag(MR))
    gQ, gR = Q.T @ one_N1, R.T @ one_N2
    # total gradients -- readjust cost for the FGW problem by adding the Wasserstein gradients
    gradQW, gradRW = Wasserstein_Grad_LR(C_factors, Q, R, Lambda, device, \
                                         dtype=dtype, full_grad=full_grad)
    gradQ = (1-alpha)*gradQW + (alpha/2)*gradQ
    gradR = (1-alpha)*gradRW + (alpha/2)*gradR
    normalizer = torch.max(torch.tensor([torch.max(torch.abs(gradQ)), torch.max(torch.abs(gradR))]))
    gamma_k = gamma / normalizer
    return gradQ, gradR, gamma_k
def compute_grad_B_LR(C_factors, A_factors, B_factors, Q, R, Lambda, gQ, gR, gamma, device, \
                      alpha=0.0, dtype=torch.float64):
    '''
    Compute the fused Gromov-Wasserstein gradient with respect to the inner matrix T
    when the distance matrices C, A, B are given through low-rank factorizations.
    Returns gradT and the rescaled step-size gamma_T.
    '''
    N1, N2 = C_factors[0].size(0), C_factors[1].size(1)
    A1, A2 = A_factors[0], A_factors[1]
    B1, B2 = B_factors[0], B_factors[1]
    # GW grad
    gradLambda = -4 * ( (Q.T @ A1) @ (A2 @ Q) ) @ Lambda @ ( (R.T @ B1) @ (B2 @ R) )
    del A1, A2, B1, B2
    C1, C2 = C_factors[0], C_factors[1]
    # total grad
    gradLambda = (1-alpha)*( (Q.T @ C1) @ (C2 @ R) ) + (alpha/2)*gradLambda
    gradT = torch.diag(1/gQ) @ gradLambda @ torch.diag(1/gR) # (mass-reweighted form)
    gamma_T = gamma / torch.max(torch.abs(gradT))
    return gradT, gamma_T
def Wasserstein_Grad_LR(C_factors, Q, R, Lambda, device, \
                        dtype=torch.float64, full_grad=True):
    '''
    Gradients of the Wasserstein (linear) term with respect to Q and R when C is
    given through a low-rank factorization C_factors = (C1, C2) with C ~= C1 @ C2.
    If full_grad is True, a rank-one correction accounting for the inner marginals is subtracted.
    '''
    C1, C2 = C_factors[0], C_factors[1]
    # linear term
    gradQ = C1 @ ((C2 @ R) @ Lambda.T)
    if full_grad:
        # rank-one perturbation
        N1 = Q.shape[0]
        one_N1 = torch.ones((N1), device=device, dtype=dtype)
        gQ = Q.T @ one_N1
        w1 = torch.diag( (gradQ.T @ Q) @ torch.diag(1/gQ) )
        gradQ -= torch.outer(one_N1, w1)
    # linear term
    gradR = C2.T @ ((C1.T @ Q) @ Lambda)
    if full_grad:
        # rank-one perturbation
        N2 = R.shape[0]
        one_N2 = torch.ones((N2), device=device, dtype=dtype)
        gR = R.T @ one_N2
        w2 = torch.diag( torch.diag(1/gR) @ (R.T @ gradR) )
        gradR -= torch.outer(one_N2, w2)
    return gradQ, gradR
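# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative only): the shapes and values below are
# arbitrary assumptions used to exercise the dense Wasserstein path; they are
# not taken from the original repository. Running it still requires the
# module-level `import util` to resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    device, dt = "cpu", torch.float64
    N1, N2, r = 6, 5, 3
    C = torch.rand((N1, N2), device=device, dtype=dt)                 # inter-space cost
    Q = torch.rand((N1, r), device=device, dtype=dt); Q /= Q.sum()    # left sub-coupling
    R = torch.rand((N2, r), device=device, dtype=dt); R /= R.sum()    # right sub-coupling
    Lambda = torch.diag(torch.full((r,), 1.0 / r, device=device, dtype=dt))
    gradQ, gradR, gamma_k = compute_grad_A(C, Q, R, Lambda, gamma=1.0,
                                           semiRelaxedLeft=False, semiRelaxedRight=False,
                                           device=device, Wasserstein=True,
                                           dtype=dt, full_grad=True)
    gQ, gR = Q.sum(dim=0), R.sum(dim=0)                               # inner marginals
    gradT, gamma_T = compute_grad_B(C, Q, R, Lambda, gQ, gR, gamma=1.0,
                                    device=device, Wasserstein=True, dtype=dt)
    print(gradQ.shape, gradR.shape, gradT.shape, float(gamma_k), float(gamma_T))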