-
Notifications
You must be signed in to change notification settings - Fork 1
/
MuMoE_test.py
122 lines (94 loc) · 4.64 KB
/
MuMoE_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from MuMoE import MuMoE, CPMuMoE, TRMuMoE
import tensorly as tl
from torch import nn
from einops import einsum
import torch
tl.set_backend('pytorch')
from entmax import entmax15
def act(x):
    """Return ``entmax15`` of *x* taken along its last dimension.

    Passed to the models below as their ``act`` argument and re-used (via
    ``model.act``) when the reference expert coefficients are recomputed.
    """
    return entmax15(x, dim=-1)


# Problem sizes shared by every test case in this script.
n_experts = 16  # first entry of `expert_dims` in every model built below
batch_size = 4  # number of rows in the random input `x`
in_dim = 4      # number of columns in `x`; passed as `input_dim`
out_dim = 8     # passed as `output_dim`

# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's RNG so the random inputs drawn below are reproducible.
torch.manual_seed(0)
# Check that the CPMuMoE layer's factorised forward pass matches the output
# obtained from the fully materialised weight tensor.
print('Testing CPMMoE...')
with torch.no_grad():
    x = torch.randn(batch_size, in_dim).to(device)

    ###########################
    ########################### hierarchy E=1
    ###########################
    rank = 16

    ############ model output
    # Place the model on the same device as `x`. `nn.Module.to` returns the
    # module itself and leaves it untouched when its parameters already live
    # on `device` (assumes MuMoE is an nn.Module -- it already exposes .eval()).
    model = MuMoE(
        input_dim=in_dim,
        output_dim=out_dim,
        MuMoE_layer=CPMuMoE,
        expert_dims=[n_experts],
        ranks=rank,
        act=act,
        normalization='bn',
    ).to(device)
    # NOTE(review): eval mode presumably freezes the 'bn' normalisation so the
    # model call and the manual `model.bn[0](...)` call below agree -- confirm.
    model.eval()
    out = model(x).squeeze(1)

    ############ unfactorized output
    # Rebuild the dense weight tensor from the CP factors, using unit weights.
    W = tl.cp_tensor.cp_to_tensor((torch.ones(rank).to(device), model.MuMoE.factors))
    a = model.act(model.bn[0](x @ model.projs[0]))  # generate mixture weights
    # The last slice along the input mode holds the bias; the rest is the
    # linear map. The bias must be read before W is trimmed.
    b = W[:, -1, :]   # get the bias term
    W = W[:, :-1, :]  # get just the linear transformation weights

    # compute the raw Eq. (1) series of tensor contractions
    real_out = einsum(W, a, x, 'n i o, b n, b i -> b o') + einsum(b, a, 'n o, b n -> b o')
    torch.testing.assert_close(out, real_out)
    print('...(hierarchy=1) factorized==unfactorized forward pass test passed')

    ###########################
    ########################### hierarchy E=2
    ###########################

    ############ model output
    # Same check with two expert modes (`expert_dims=[n_experts, 2]`); the
    # model is again moved to `device` to match the input.
    model = MuMoE(
        input_dim=in_dim,
        output_dim=out_dim,
        MuMoE_layer=CPMuMoE,
        expert_dims=[n_experts, 2],
        ranks=rank,
        act=act,
        normalization='bn',
        hierarchy=2,
    ).to(device)
    model.eval()
    out = model(x).squeeze(1)

    ############ unfactorized output
    W = tl.cp_tensor.cp_to_tensor((torch.ones(rank).to(device), model.MuMoE.factors))
    # generate expert coefficients, one set per expert mode
    a1 = model.act(model.bn[0](x @ model.projs[0]))
    a2 = model.act(model.bn[1](x @ model.projs[1]))
    # W now carries two leading expert modes, so the input mode is axis 2.
    b = W[:, :, -1, :]   # get the bias term
    W = W[:, :, :-1, :]  # get just the linear transformation weights

    # compute the raw Eq. (1) series of tensor contractions
    real_out = einsum(W, a1, a2, x, 'n1 n2 i o, b n1, b n2, b i -> b o') + einsum(b, a1, a2, 'n1 n2 o, b n1, b n2 -> b o')
    torch.testing.assert_close(out, real_out)
    print('...(hierarchy=2) factorized==unfactorized forward pass test passed')
# Check that the TRMuMoE layer's factorised forward pass matches the output
# obtained from the fully materialised weight tensor.
print('Testing TRMMoE...')
with torch.no_grad():
    x = torch.randn(batch_size, in_dim).to(device)

    ###########################
    ########################### hierarchy E=1
    ###########################
    r1 = 4 ; r2 = 4 ; r3 = 4
    # One [left rank, mode size, right rank] triple per mode; the final right
    # rank repeats r1.
    ranks = [[r1, n_experts, r2], [r2, in_dim, r3], [r3, out_dim, r1]]

    ############ model output
    # Place the model on the same device as `x`. `nn.Module.to` returns the
    # module itself and leaves it untouched when its parameters already live
    # on `device` (assumes MuMoE is an nn.Module -- it already exposes .eval()).
    model = MuMoE(
        input_dim=in_dim,
        output_dim=out_dim,
        MuMoE_layer=TRMuMoE,
        expert_dims=[n_experts],
        ranks=ranks,
        act=act,
        normalization='bn',
    ).to(device)
    # NOTE(review): eval mode presumably freezes the 'bn' normalisation so the
    # model call and the manual `model.bn[0](...)` call below agree -- confirm.
    model.eval()
    out = model(x).squeeze(1)

    ############ unfactorized output
    # Rebuild the dense weight tensor from the tensor-ring factors.
    W = tl.tr_tensor.tr_to_tensor(model.MuMoE.factors)
    a = model.act(model.bn[0](x @ model.projs[0]))  # generate mixture weights
    # The last slice along the input mode holds the bias; the rest is the
    # linear map. The bias must be read before W is trimmed.
    b = W[:, -1, :]   # get the bias term
    W = W[:, :-1, :]  # get just the linear transformation weights

    # compute the raw Eq. (1) series of tensor contractions
    real_out = einsum(W, a, x, 'n i o, b n, b i -> b o') + einsum(b, a, 'n o, b n -> b o')
    torch.testing.assert_close(out, real_out)
    print('...(hierarchy=1) factorized==unfactorized forward pass test passed')

    ###########################
    ########################### hierarchy E=2
    ###########################
    r1 = 4 ; r2 = 4 ; r3 = 4 ; r4 = 4
    # An extra triple is added for the second expert mode (size 2).
    ranks = [[r1, n_experts, r2], [r2, 2, r3], [r3, in_dim, r4], [r4, out_dim, r1]]

    ############ model output
    # Same check with two expert modes (`expert_dims=[n_experts, 2]`); the
    # model is again moved to `device` to match the input.
    model = MuMoE(
        input_dim=in_dim,
        output_dim=out_dim,
        MuMoE_layer=TRMuMoE,
        expert_dims=[n_experts, 2],
        ranks=ranks,
        act=act,
        normalization='bn',
        hierarchy=2,
    ).to(device)
    model.eval()
    out = model(x).squeeze(1)

    ############ unfactorized output
    W = tl.tr_tensor.tr_to_tensor(model.MuMoE.factors)
    # generate expert coefficients, one set per expert mode
    a1 = model.act(model.bn[0](x @ model.projs[0]))
    a2 = model.act(model.bn[1](x @ model.projs[1]))
    # W now carries two leading expert modes, so the input mode is axis 2.
    b = W[:, :, -1, :]   # get the bias term
    W = W[:, :, :-1, :]  # get just the linear transformation weights

    # compute the raw Eq. (1) series of tensor contractions
    real_out = einsum(W, a1, a2, x, 'n1 n2 i o, b n1, b n2, b i -> b o') + einsum(b, a1, a2, 'n1 n2 o, b n1, b n2 -> b o')
    torch.testing.assert_close(out, real_out)
    print('...(hierarchy=2) factorized==unfactorized forward pass test passed')