# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional, Union

import torch

from xformers.components import Activation
from xformers.components.feedforward import (
    Feedforward,
    FeedforwardConfig,
    register_feedforward,
)

logger = logging.getLogger("xformers")


_is_fairscale_available = True

try:
    import torch.distributed as dist
    from fairscale.nn import MOELayer, Top2Gate  # type: ignore

    from xformers.components.feedforward import MLP

except ImportError:
    logger.warning(
        "Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed."
        " Please install them if you would like to use MoE"
    )

    _is_fairscale_available = False
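
# Note (added for clarity): FairScale is not a hard dependency of xformers; it can
# usually be installed separately, e.g. with `pip install fairscale`, while
# torch.distributed ships with PyTorch itself.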

if _is_fairscale_available:

    # Credits: initially implemented in FairScale for sanity checking
    class RoundRobinGate(torch.nn.Module):
        def __init__(self, model_dim, num_experts):
            super().__init__()
            self.model_dim = model_dim
            self.num_experts = num_experts

        def forward(self, input):
            s = input.shape[0]
            assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0"
            capacity = 2 * s // self.num_experts
            output = torch.zeros(
                s, self.num_experts, capacity, dtype=input.dtype, device=input.device
            )
            for i in range(s):
                output[i, i % self.num_experts, i // self.num_experts] = 1.0
            return 0.0, output, output.bool()
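
    # Illustrative addition (not part of the upstream file): a minimal sketch of
    # what RoundRobinGate returns. For s tokens and e experts it yields a zero
    # auxiliary loss, a dense [s, e, capacity] combine tensor and the matching
    # boolean dispatch mask, routing token i to expert i % e.
    def _round_robin_gate_sketch() -> None:
        gate = RoundRobinGate(model_dim=16, num_experts=4)
        tokens = torch.randn(8, 16)  # 8 tokens, 8 % 4 == 0 as the gate requires
        loss, combine, dispatch = gate(tokens)
        assert loss == 0.0
        assert combine.shape == (8, 4, 4)  # capacity = 2 * 8 // 4 = 4
        assert dispatch.dtype == torch.bool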

    class GateConfig(str, Enum):
        RoundRobin = "round_robin"
        Top2 = "top_2"
        # Other gating techniques could be exposed here

    @dataclass
    class MoEConfig(FeedforwardConfig):
        number_of_experts: int
        gate: GateConfig
        number_of_local_experts: Optional[int] = None
        expert_constructor: Optional[Any] = None
        hidden_layer_multiplier: Optional[int] = None
        group: Optional[Any] = None
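
    # Illustrative note (not part of the upstream file): since the class below is
    # registered via register_feedforward, it can typically also be built from a
    # config dict through xformers' feedforward factory. The factory name and the
    # exact field names here are assumptions based on FeedforwardConfig/MoEConfig:
    #
    #   from xformers.components.feedforward import build_feedforward
    #   moe = build_feedforward(
    #       {
    #           "name": "MixtureOfExperts",
    #           "dim_model": 512,
    #           "dropout": 0.1,
    #           "activation": "relu",
    #           "number_of_experts": 4,
    #           "gate": "top_2",
    #       }
    #   )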

    @register_feedforward("MixtureOfExperts", MoEConfig)
    class MixtureOfExperts(Feedforward):
        """
        An MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_.
        xFormers uses the FairScale_ implementation under the hood.

        .. warning:: Please note that most of the benefits of MoE only materialize in a distributed training environment.

        .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
        .. _FairScale: https://github.com/facebookresearch/fairscale/
        """

        def __init__(
            self,
            dim_model: int,
            dropout: float,
            activation: Activation,
            number_of_experts: int,
            gate: Union[GateConfig, torch.nn.Module],
            number_of_local_experts: Optional[int] = None,
            expert_constructor: Optional[Callable[[], torch.nn.Module]] = None,
            hidden_layer_multiplier: Optional[int] = None,
            group: Optional[Any] = None,
            *_,
            **__,
        ):
            super().__init__()

            # Handle a possibly uninitialized process group
            assert (
                dist.is_initialized()
            ), "MixtureOfExperts requires torch.distributed to be initialized"

            if number_of_local_experts is not None:
                assert number_of_experts >= number_of_local_experts
            else:
                if dist.get_world_size() == 1:
                    logger.warning("Local experts not specified but world size is 1")
                    logger.warning("Assuming that all experts are local")
                    number_of_local_experts = number_of_experts
                else:
                    number_of_local_experts = 1
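
            # Note (added for clarity): number_of_experts is the global expert count
            # across all ranks, while number_of_local_experts is how many expert MLPs
            # are instantiated on this rank and handed to FairScale's MOELayer below.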

            # Programmatically handle the gating technique
            if not isinstance(gate, torch.nn.Module):
                gate_constructor = {
                    GateConfig.RoundRobin: RoundRobinGate,
                    GateConfig.Top2: Top2Gate,
                }[gate]

                self.gate = gate_constructor(dim_model, number_of_experts)
            else:
                self.gate = gate

            # Programmatically handle the experts
            if expert_constructor is None:
                multiplier = (
                    hidden_layer_multiplier
                    if hidden_layer_multiplier is not None
                    else 4
                )

                def expert_constructor() -> torch.nn.Module:
                    return MLP(dim_model, dropout, activation, multiplier)

                assert expert_constructor is not None

            local_experts = torch.nn.ModuleList(
                [expert_constructor() for _ in range(number_of_local_experts)]
            )

            self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group)

            self.requires_cuda = True

        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            # FairScale MoE assumes that the dimensions are [S, B, E],
            # xFormers assumes [B, S, E], hence the permutations below
            return self.moe(inputs.movedim(0, 1)).movedim(0, 1)
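
    # Illustrative addition (not part of the upstream file): a rough sketch of how
    # this block might be used once torch.distributed has been initialized. The
    # sizes, activation and gate choice are assumptions for the example only;
    # FairScale's MOELayer performs an all-to-all exchange, so in practice this
    # runs with one (or more) GPU(s) per rank, typically over NCCL.
    def _mixture_of_experts_sketch() -> torch.Tensor:
        assert dist.is_initialized(), "call torch.distributed.init_process_group first"

        moe = MixtureOfExperts(
            dim_model=512,
            dropout=0.1,
            activation=Activation.ReLU,
            number_of_experts=dist.get_world_size(),
            gate=GateConfig.Top2,
        ).cuda()

        # xFormers layout: [batch, sequence, embedding]
        x = torch.randn(2, 64, 512, device="cuda")
        return moe(x)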