-
Notifications
You must be signed in to change notification settings - Fork 93
/
fairseq_patch.diff
131 lines (125 loc) · 5.45 KB
/
fairseq_patch.diff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
diff --git a/fairseq/models/transformer/transformer_decoder.py b/fairseq/models/transformer/transformer_decoder.py
index 61aaa09..458bd40 100644
--- a/fairseq/models/transformer/transformer_decoder.py
+++ b/fairseq/models/transformer/transformer_decoder.py
@@ -115,9 +115,14 @@ class TransformerDecoderBase(FairseqIncrementalDecoder):
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
else:
self.layers = nn.ModuleList([])
+
+ def config_with_index(cfg, index):
+ cfg.transformer_index = index
+ return cfg
+
self.layers.extend(
[
- self.build_decoder_layer(cfg, no_encoder_attn)
+ self.build_decoder_layer(config_with_index(cfg, _), no_encoder_attn)
for _ in range(cfg.decoder.layers)
]
)
diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py
index 2e687b9..8c1166b 100644
--- a/fairseq/modules/transformer_layer.py
+++ b/fairseq/modules/transformer_layer.py
@@ -324,18 +324,33 @@ class TransformerDecoderLayerBase(nn.Module):
else None
)
- self.fc1 = self.build_fc1(
- self.embed_dim,
- cfg.decoder.ffn_embed_dim,
- self.quant_noise,
- self.quant_noise_block_size,
- )
- self.fc2 = self.build_fc2(
- cfg.decoder.ffn_embed_dim,
- self.embed_dim,
- self.quant_noise,
- self.quant_noise_block_size,
- )
+ self.moe_freq = int(torch.os.environ.get('MOE', 0))
+ self.use_moe = (self.moe_freq > 0) and (cfg.transformer_index + 1) % self.moe_freq == 0
+
+ if self.use_moe:
+ assert self.quant_noise == 0, "Unhandled quant_noise > 0.0 for MoE layer."
+ from tutel.moe import moe_layer
+ self.moe_ffn = moe_layer(
+ gate_type={'type' : 'top', 'k' : 2, 'capacity_factor': 0.0, 'fp32_gate': True, 'gate_noise': 1.0},
+ model_dim=self.embed_dim,
+ experts={'count_per_node': 1,'type': 'ffn', 'hidden_size_per_expert': cfg.decoder.ffn_embed_dim,
+ 'activation_fn' : lambda x:
+ self.activation_dropout_module(x) if self.ffn_layernorm is None else self.ffn_layernorm(self.activation_dropout_module(x))},
+ scan_expert_func = lambda name, param: setattr(param, 'expert', True), # The mask is only compatible with Fairseq based on legacy_ddp
+ )
+ else:
+ self.fc1 = self.build_fc1(
+ self.embed_dim,
+ cfg.decoder.ffn_embed_dim,
+ self.quant_noise,
+ self.quant_noise_block_size,
+ )
+ self.fc2 = self.build_fc2(
+ cfg.decoder.ffn_embed_dim,
+ self.embed_dim,
+ self.quant_noise,
+ self.quant_noise_block_size,
+ )
self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.need_attn = True
@@ -504,11 +519,18 @@ class TransformerDecoderLayerBase(nn.Module):
if self.normalize_before:
x = self.final_layer_norm(x)
- x = self.activation_fn(self.fc1(x))
- x = self.activation_dropout_module(x)
- if self.ffn_layernorm is not None:
- x = self.ffn_layernorm(x)
- x = self.fc2(x)
+ if self.use_moe:
+ x = self.moe_ffn(x)
+ from tutel import system
+ if x.l_aux.requires_grad:
+ system.cache().set(id(self.moe_ffn), (x.numel() // x.size(-1), x.l_aux))
+ else:
+ x = self.activation_fn(self.fc1(x))
+ x = self.activation_dropout_module(x)
+ if self.ffn_layernorm is not None:
+ x = self.ffn_layernorm(x)
+ x = self.fc2(x)
+
x = self.dropout_module(x)
if self.w_resid is not None:
residual = torch.mul(self.w_resid, residual)
diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py
index 2c4ee32..15c264a 100644
--- a/fairseq/optim/fp16_optimizer.py
+++ b/fairseq/optim/fp16_optimizer.py
@@ -207,6 +207,11 @@ class _FP16OptimizerMixin(object):
def step(self, closure=None, groups=None):
"""Performs a single optimization step."""
+ if int(torch.os.environ.get('NO_OVERFLOW', 0)) > 0:
+ for x, y in zip(self.fp16_params, self.fp32_params):
+ x.grad[torch.isinf(x.grad)] = 0
+ y.grad[torch.isinf(y.grad)] = 0
+
self._sync_fp16_grads_to_fp32()
if getattr(self, "supports_step_with_scale", False):
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 273dbdd..4c8f06e 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -513,6 +513,16 @@ class FairseqTask(object):
if ignore_grad:
loss *= 0
with torch.autograd.profiler.record_function("backward"):
+ from tutel import system
+ l_aux_wt = float(torch.os.environ.get('L_AUX_WT', 0.0))
+ if l_aux_wt:
+ l_aux = None
+ for samples, x in system.cache().get():
+ x *= l_aux_wt * samples
+ l_aux = x if l_aux is None else l_aux + x
+ system.cache().reset()
+ if l_aux is not None:
+ loss += l_aux
optimizer.backward(loss)
return loss, sample_size, logging_output