add implicit dropout layer if p > 0 (#21)
msftsw authored Oct 10, 2021
1 parent bd61492 commit b1468bd
Showing 4 changed files with 16 additions and 5 deletions.
examples/helloworld.py: 2 changes (1 addition & 1 deletion)

@@ -33,7 +33,7 @@
 args = parser.parse_args()

 if args.local_rank < 0:
-    args.local_rank = int(os.environ['LOCAL_RANK'])
+    args.local_rank = int(os.environ.get('LOCAL_RANK', 0))

 torch.cuda.set_device(args.local_rank)
examples/helloworld_ddp.py: 2 changes (1 addition & 1 deletion)

@@ -34,7 +34,7 @@
 args = parser.parse_args()

 if args.local_rank < 0:
-    args.local_rank = int(os.environ['LOCAL_RANK'])
+    args.local_rank = int(os.environ.get('LOCAL_RANK', 0))

 torch.cuda.set_device(args.local_rank)
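Both examples previously did a hard os.environ['LOCAL_RANK'] lookup, which raises KeyError when the script is run without a distributed launcher. With this change, a missing LOCAL_RANK falls back to rank 0. A minimal sketch of the new behavior; the -1 default for --local_rank is an assumption inferred from the `args.local_rank < 0` check:

import os

# Sketch of the examples' fallback logic after this change. The -1 value
# stands in for what argparse would leave when --local_rank is not passed
# (assumed from the `args.local_rank < 0` check in the examples).
local_rank = -1

if local_rank < 0:
    # Old: int(os.environ['LOCAL_RANK']) -> KeyError without a launcher.
    # New: fall back to rank 0 so single-process runs still work.
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

print(local_rank)  # 0 when launched directly; the launcher-set rank under torchrun / torch.distributed.launch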
tutel/impls/moe_layer.py: 11 changes (11 additions & 0 deletions)

@@ -149,6 +149,7 @@ def __init__(self, gate_type, model_dim: int, experts = None, fp32_gate = False,
     fused_custom_fn = experts.get('fused_custom_fn')
     if fused_custom_fn is None:
         activation_fn = experts.get('activation_fn', lambda x: F.relu(x))
+        implicit_dropout_p = experts.get('implicit_dropout_p', 0)

     class FusedExpertsNetwork(torch.nn.Module):
         def __init__(self, model_dim, hidden_size, local_experts):
@@ -184,6 +185,12 @@ def __init__(self, model_dim, hidden_size, local_experts):
         self.register_parameter(name='fc1_bias', param=torch.nn.Parameter(fc1_bias))
         self.register_parameter(name='fc2_bias', param=torch.nn.Parameter(fc2_bias))

+        if implicit_dropout_p:
+            self.dropout_fc1 = torch.nn.Dropout(p=implicit_dropout_p)
+            self.dropout_fc2 = torch.nn.Dropout(p=implicit_dropout_p)
+        else:
+            self.dropout_fc1 = self.dropout_fc2 = lambda x: x
+
     def extra_repr(self):
         return 'model_dim=%d, hidden_size=%d, local_experts=%d, bias=%s' % (self.model_dim, self.hidden_size, self.local_experts, self.fc1_bias is not None)

@@ -196,14 +203,18 @@ def forward(self, x):
             original_shape, x = x.shape, x.view(-1, self.model_dim)
             x = torch.addmm(self.fc1_bias, x, self.fc1_weight)
             x = activation_fn(x)
+            x = self.dropout_fc1(x)
             x = torch.addmm(self.fc2_bias, x, self.fc2_weight)
+            x = self.dropout_fc2(x)
             x = x.view(original_shape)
         else:
             x = x.permute(1, 0, 2, 3)
             original_shape, x = x.shape, x.reshape(self.local_experts, -1, self.model_dim)
             x = torch.matmul(x, self.fc1_weight) + self.fc1_bias
             x = activation_fn(x)
+            x = self.dropout_fc1(x)
             x = torch.matmul(x, self.fc2_weight) + self.fc2_bias
+            x = self.dropout_fc2(x)
             x = x.reshape(self.local_experts, original_shape[1], original_shape[2], self.model_dim)
             x = x.permute(1, 0, 2, 3)
         return x
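For users, the new dropout is opt-in through the experts dictionary passed to the MoE layer constructor; gate_type, model_dim, and experts are the parameters visible in the hunk header above. A hedged usage sketch follows: apart from 'activation_fn' and 'implicit_dropout_p', which appear in this diff, the remaining expert keys and the gate spelling are assumptions based on the repository's examples and may differ between versions.

import torch
from tutel import moe as tutel_moe  # import path assumed from the repository layout

# Hedged sketch: enabling the new implicit dropout from user code.
# Only 'activation_fn' and 'implicit_dropout_p' are confirmed by this diff;
# the other keys and values below are illustrative assumptions.
moe = tutel_moe.moe_layer(
    gate_type='Top2Gate',                      # assumed gate name for this version
    model_dim=1024,
    experts={
        'type': 'ffn',                         # assumed expert type key
        'count_per_node': 2,                   # assumed key for local expert count
        'hidden_size_per_expert': 4096,        # assumed key for FFN hidden size
        'activation_fn': lambda x: torch.nn.functional.relu(x),
        'implicit_dropout_p': 0.1,             # new in this commit: Dropout after fc1 and after fc2
    },
)

With implicit_dropout_p left at 0, the layer keeps the identity lambdas, so existing configurations are unaffected; as with any torch.nn.Dropout, the added dropout is only active in training mode.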
tutel/system_init.py: 6 changes (3 additions & 3 deletions)

@@ -7,14 +7,14 @@
 def init_affinity_at_program_beginning():
     try:
         numa_type = int(os.environ.get('NUMA_TYPE', '1'))
-        if numa_type == 0:
+        if numa_type <= 0:
             return
         group_rank = int(os.environ.get('LOCAL_RANK', '0'))
         nodes = sorted([int(x[4:]) for x in os.listdir('/sys/devices/system/node') if re.match('node[0-9]+', x)])
         cpus = [sorted([int(x[3:]) for x in os.listdir('/sys/devices/system/node/node%d' % node_id) if re.match('cpu[0-9]+', x)]) for node_id in nodes]
-        sel_node = group_rank % len(nodes)
+        sel_node = (group_rank // numa_type) % len(nodes)
         os.sched_setaffinity(0, cpus[sel_node])
-        print('[INFO] LOCAL_RANK %d is to set NUMA node: %d' % (group_rank, sel_node))
+        print('[INFO] LOCAL_RANK %d is to set NUMA node: %d / %d' % (group_rank, sel_node, len(nodes)))
     except Exception as ex:
         if group_rank == 0:
             print('[WARN] Failed to set NUMA status: %s' % ex)
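With the revised formula, NUMA_TYPE effectively acts as the number of consecutive local ranks that share one NUMA node (an interpretation inferred from the `(group_rank // numa_type) % len(nodes)` expression) rather than a plain on/off switch. A small self-contained sketch of the mapping, with the node count hard-coded instead of discovered under /sys/devices/system/node:

# Hedged sketch of the revised NUMA-node selection; rank range and node count
# are illustrative only.
def select_numa_node(group_rank, numa_type, num_nodes):
    return (group_rank // numa_type) % num_nodes

num_nodes = 2
for rank in range(8):                       # e.g. 8 local ranks on one host
    old = rank % num_nodes                  # previous formula: ranks alternate across nodes
    new = select_numa_node(rank, numa_type=4, num_nodes=num_nodes)
    print('rank %d: old node %d -> new node %d' % (rank, old, new))
# With NUMA_TYPE=4, ranks 0-3 pin to node 0 and ranks 4-7 to node 1,
# instead of alternating 0,1,0,1,... as before.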
