-
Notifications
You must be signed in to change notification settings - Fork 9
/
comm.py
309 lines (259 loc) · 11.7 KB
/
comm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import os
import logging
from utils.logging_utils import disable_logging
import torch
import math
import numpy as np
import torch.distributed as dist
import datetime as dt
from typing import Union
# Module-level communicator registry, filled in by init_model_parallel_info().
#   _COMM_LIST       - process groups this rank belongs to, in registration order;
#                      an entry may be None (the "data" comm when it spans the world).
#   _COMM_NAMES      - communicator name -> index into _COMM_LIST.
#   _COMM_NAMES_META - names registered as meta comms (merged comms, "model", "data");
#                      get_names(meta=False) filters these out.
_COMM_LIST: list = list()
_COMM_NAMES: dict = dict()
_COMM_NAMES_META: list = list()
# Lookups for named or indexed communicators in the module registry.
def get_size(comm_id: Union[str, int]) -> int:
    """Return the number of ranks in the communicator ``comm_id``.

    ``comm_id`` is either an index into the communicator registry or a
    registered communicator name. An unknown name, an index past the end of
    the registry, or an uninitialized torch.distributed all report size 1.
    """
    if isinstance(comm_id, int):
        index = comm_id
    elif comm_id in _COMM_NAMES:
        index = _COMM_NAMES[comm_id]
    else:
        # unknown names map one past the end, i.e. "not available"
        index = len(_COMM_LIST)

    available = dist.is_initialized() and index < len(_COMM_LIST)
    if not available:
        return 1
    return dist.get_world_size(group=_COMM_LIST[index])
def get_rank(comm_id: Union[str, int]) -> int:
    """Return this process's rank inside the communicator ``comm_id``.

    ``comm_id`` may be a registry index or a registered name. When the
    communicator is unknown/out of range, or torch.distributed is not
    initialized, the rank is reported as 0.
    """
    idx = comm_id if isinstance(comm_id, int) else _COMM_NAMES.get(comm_id, len(_COMM_LIST))
    if dist.is_initialized() and idx < len(_COMM_LIST):
        return dist.get_rank(group=_COMM_LIST[idx])
    return 0
def get_group(comm_id: Union[str, int]) -> "Union[dist.ProcessGroup, None]":
    """Return the process group registered for the communicator ``comm_id``.

    Args:
        comm_id: an index into the communicator registry or a registered name.

    Returns:
        The stored process group object. This may be ``None`` for the "data"
        communicator when it spans all ranks (the registry stores ``None`` there,
        which torch.distributed interprets as the default group).

    Raises:
        IndexError: if torch.distributed is not initialized or ``comm_id`` does
            not resolve to a registered communicator.
    """
    if isinstance(comm_id, int):
        cid = comm_id
    else:
        # unknown names resolve one past the end so they hit the error below
        cid = _COMM_NAMES[comm_id] if (comm_id in _COMM_NAMES) else len(_COMM_LIST)
    if not dist.is_initialized() or (cid >= len(_COMM_LIST)):
        raise IndexError(f"Error, comm with id {comm_id} not available.")
    return _COMM_LIST[cid]
# Convenience wrappers for the default (world) process group.
def get_world_size():
    """Return the total number of ranks, or 1 if torch.distributed is not initialized."""
    return dist.get_world_size() if dist.is_initialized() else 1
def get_world_rank():
    """Return this process's global rank, or 0 if torch.distributed is not initialized."""
    if dist.is_initialized():
        return dist.get_rank()
    return 0
def get_local_rank():
    """Return the node-local rank of the current process.

    Computed as ``world_rank % num_gpu``. ``num_gpu`` comes from the
    ``NGPU_PER_NODE`` environment variable and falls back to
    ``torch.cuda.device_count()``. Returns 0 when torch.distributed is not
    initialized.

    Raises:
        RuntimeError: if the resolved number of GPUs per node is not positive
            (this used to surface as a bare ZeroDivisionError).
    """
    # NOTE(review): the PyTorch-provided LOCAL_RANK env var is intentionally not
    # consulted here; the original code had that branch disabled via `and False`.
    if not dist.is_initialized():
        return 0
    ngpu_env = os.getenv("NGPU_PER_NODE")
    num_gpu = int(ngpu_env) if ngpu_env is not None else torch.cuda.device_count()
    if num_gpu <= 0:
        raise RuntimeError(
            f"Cannot derive local rank: number of GPUs per node is {num_gpu}; "
            "set NGPU_PER_NODE or make CUDA devices visible."
        )
    return get_world_rank() % num_gpu
def get_names(meta=True):
    """Return the registered communicator names.

    With ``meta=True`` (default) the registry's own name -> index dict is
    returned (the live module object, not a copy). With ``meta=False`` a list of
    names is returned, excluding those recorded as meta communicators.
    """
    if not meta:
        return [name for name in _COMM_NAMES if name not in _COMM_NAMES_META]
    return _COMM_NAMES
def is_distributed(name: str):
    """Return True if a communicator called ``name`` is registered in this process's table."""
    registered = _COMM_NAMES
    return name in registered
def init(params, verbose=False):
    """Wire up torch.distributed and build the model-parallel communicators.

    The global process group is initialized from ``params.wireup_info`` and
    ``params.wireup_store``. Named model-parallel communicators are then created
    from ``params["model_parallel_names"]`` / ``params["model_parallel_sizes"]``
    (defaulting to a single communicator "model" of size 1). The resulting total
    model-parallel size is written back to ``params.model_parallel_size``.
    """
    init_process_group(info=params.wireup_info, store=params.wireup_store)

    # individual wireup for the model-parallel communicators
    sizes = params.get("model_parallel_sizes", [1])
    names = params.get("model_parallel_names", ["model"])
    total_mp_size = init_model_parallel_info(names=names, sizes=sizes, verbose=verbose)
    params.model_parallel_size = total_mp_size
def _wireup_from_env():
    """Read (world_size, world_rank, port, master_address) from environment variables."""
    world_size = int(os.getenv('WORLD_SIZE', 1))
    world_rank = int(os.getenv('RANK', 0))
    if os.getenv('WORLD_RANK') is not None:
        # WORLD_RANK takes precedence for backwards compatibility
        world_rank = int(os.getenv('WORLD_RANK'))
    port = int(os.getenv('MASTER_PORT', 0))
    # MASTER_ADDRESS takes precedence over MASTER_ADDR for backwards compatibility
    master_address = os.getenv('MASTER_ADDRESS', os.getenv('MASTER_ADDR'))
    return world_size, world_rank, port, master_address


def _wireup_from_mpi():
    """Derive (world_size, world_rank, port, master_address) via MPI; rank 0's host is the master."""
    import socket
    from mpi4py import MPI

    mpi_comm = MPI.COMM_WORLD.Dup()
    world_size = mpi_comm.Get_size()
    world_rank = mpi_comm.Get_rank()
    port = 29500
    master_address = None
    if world_rank == 0:
        my_host = socket.gethostname()
        master_address_info = socket.getaddrinfo(my_host, port, family=socket.AF_INET, proto=socket.IPPROTO_TCP)
        master_address = master_address_info[0][-1][0]
    master_address = mpi_comm.bcast(master_address, root=0)
    # the duplicated communicator was only needed for the broadcast above
    mpi_comm.Free()

    os.environ["MASTER_ADDRESS"] = master_address
    # MASTER_ADDR is the variable torch's default env:// rendezvous reads
    # (used when no explicit store is created below)
    os.environ["MASTER_ADDR"] = master_address
    os.environ["MASTER_PORT"] = str(port)
    return world_size, world_rank, port, master_address


def _create_store(kind, world_size, world_rank, master_address, port):
    """Build the rendezvous store for ``kind``; None lets torch use its default init method."""
    if kind == "file":
        wireup_file_path = os.getenv('WIREUP_FILE_PATH')
        if wireup_file_path is None:
            raise ValueError("Error, wireup-store file requires the WIREUP_FILE_PATH environment variable")
        return dist.FileStore(wireup_file_path, world_size)
    if kind == "tcp":
        if master_address is None:
            raise ValueError("Error, wireup-store tcp requires MASTER_ADDR or MASTER_ADDRESS to be set")
        return dist.TCPStore(host_name=master_address,
                             port=port,
                             world_size=world_size,
                             is_master=(world_rank == 0),
                             timeout=dt.timedelta(seconds=900))
    return None


def init_process_group(info: str, store: str):
    """Initialize the global torch.distributed process group (NCCL backend).

    Args:
        info: how world size, rank and master address are discovered.
            ``"env"`` reads WORLD_SIZE, RANK (or WORLD_RANK), MASTER_ADDR (or
            MASTER_ADDRESS) and MASTER_PORT. ``"mpi"`` uses mpi4py, where rank
            0's host becomes the master on port 29500.
        store: rendezvous store. ``"file"`` uses a FileStore at WIREUP_FILE_PATH;
            ``"tcp"`` uses a TCPStore on the master address (900 s timeout).
            Any other value passes ``store=None`` so torch falls back to its
            default (env://) init method.

    Nothing is initialized when the world size is 1.

    Raises:
        ValueError: for an unsupported ``info`` value, or when the selected
            store is missing its required environment configuration.
    """
    if info == "env":
        world_size, world_rank, port, master_address = _wireup_from_env()
    elif info == "mpi":
        world_size, world_rank, port, master_address = _wireup_from_mpi()
    else:
        raise ValueError(f"Error, wireup-info {info} not supported")

    if world_size <= 1:
        return

    with disable_logging():
        rendezvous_store = _create_store(store, world_size, world_rank, master_address, port)
        dist.init_process_group(backend='nccl',
                                rank=world_rank,
                                world_size=world_size,
                                store=rendezvous_store)
def init_model_parallel_info(names, sizes, verbose=False):
    """Create the model-parallel communicators and register them in _COMM_LIST / _COMM_NAMES.

    Args:
        names: communicator names, one per entry of ``sizes``.
        sizes: number of ranks along each model-parallel dimension. Their
            product is the total model-parallel size and must evenly divide
            the world size.
        verbose: if True, rank 0 prints the rank groups created for each
            communicator.

    Returns:
        The total model-parallel size, ``math.prod(sizes)``.

    When world_size > 1 this registers:
      * one communicator per name. The first name groups consecutive ranks
        (stride 1); each later name groups ranks with a stride equal to the
        product of the preceding sizes.
      * the meta communicator "matmul", merged from "row_matmul" and
        "col_matmul" (a no-op if both have size 1 on this rank).
      * "model": each consecutive block of model_parallel_size ranks.
      * "data": ranks at the same position in every "model" block. When there
        is no model parallelism, the registry stores None here, i.e. the
        default group.
    Nothing is registered when world_size == 1.

    Only groups that contain the calling rank are stored, so the registry is
    per-rank.
    """
    world_size = get_world_size()
    world_rank = get_world_rank()
    # NOTE(review): local_rank is never used below.
    local_rank = get_local_rank()
    model_parallel_names = names
    model_parallel_sizes = sizes
    assert(len(model_parallel_names) == len(model_parallel_sizes)), "Please specify names for your communicators"
    model_parallel_size = math.prod(model_parallel_sizes)
    assert ( (world_size % model_parallel_size == 0) ), \
        "Error, please make sure that the product of model parallel ranks evenly divides the total number of ranks"
    # we set this to be orthogonal to the MP groups
    # we can play tricks with the ddp_group later, in case if all the weights are shared
    data_parallel_size = world_size // model_parallel_size
    # create orthogonal communicators first
    # (the registries are only mutated, never rebound, so these global
    # statements are not strictly required; _COMM_NAMES_META is not declared)
    global _COMM_LIST
    global _COMM_NAMES
    if world_size > 1:
        # set up the strides:
        # NOTE(review): model_parallel_sizes_reversed is never used below.
        model_parallel_sizes_reversed = model_parallel_sizes[::-1]
        # grid of model-parallel ranks, shaped by the sizes in reverse order so the
        # first name's dimension is the last (fastest-varying) axis
        model_grid = np.reshape(np.arange(0, model_parallel_size), model_parallel_sizes[::-1])
        # rolling the axes by one moves the next name's dimension to the last axis
        perm = np.roll(np.arange(0,len(model_parallel_sizes)), 1).tolist()
        ranks_lookup = {}
        # NOTE(review): indices restart at 0 although _COMM_LIST is only appended
        # to, so this assumes a single call on an empty registry; a second call
        # would register wrong indices in _COMM_NAMES.
        comm_count = 0
        for mpname in model_parallel_names:
            # each row of base_group is one group of ranks along the current last axis
            base_group = np.reshape(model_grid, (-1, model_grid.shape[-1]))
            model_groups = []
            # replicate the pattern for every model-parallel block of the world
            for goffset in range(0, world_size, model_parallel_size):
                model_groups += sorted((goffset + base_group).tolist())
            if verbose and world_rank == 0:
                print(f"Creating comm groups for id {mpname}: {model_groups}")
            for grp in model_groups:
                if len(grp) > 1:
                    # every rank calls new_group for every group (torch.distributed
                    # expects all processes to enter new_group); only the group
                    # containing this rank is kept
                    tmp_group = dist.new_group(ranks = grp)
                    if world_rank in grp:
                        _COMM_LIST.append(tmp_group)
                        _COMM_NAMES[mpname] = comm_count
                        comm_count += 1
            ranks_lookup[mpname] = model_groups
            # go for the next step
            model_grid = np.transpose(model_grid, perm)
        # helper routine for creating meta comms
        # If exactly one of the two comms has size > 1 on this rank, its existing
        # group is re-registered under merge_name. If both do, new groups are
        # built from the unions of overlapping rank groups. Returns the updated
        # comm_count.
        def merge_comms(comm_count, ranks_lookup, comm_name_1, comm_name_2, merge_name):
            if ((get_size(comm_name_1) == 1) and (get_size(comm_name_2) > 1)):
                if verbose and world_rank == 0:
                    print(f'Creating comm groups for id {merge_name}: {ranks_lookup[comm_name_2]}')
                _COMM_LIST.append(get_group(comm_name_2))
                _COMM_NAMES[merge_name] = comm_count
                _COMM_NAMES_META.append(merge_name)
                comm_count += 1
            elif ((get_size(comm_name_1) > 1) and (get_size(comm_name_2) == 1)):
                if verbose and world_rank == 0:
                    print(f'Creating comm groups for id {merge_name}: {ranks_lookup[comm_name_1]}')
                _COMM_LIST.append(get_group(comm_name_1))
                _COMM_NAMES[merge_name] = comm_count
                _COMM_NAMES_META.append(merge_name)
                comm_count += 1
            elif ((get_size(comm_name_1) > 1) and (get_size(comm_name_2) > 1)):
                # fuse the lists:
                # repeatedly union any two rank sets that share a rank until no
                # two sets overlap (connected components of the combined groups).
                # Note: pooled is modified while being enumerated; the outer
                # while loop repeats until a full pass makes no merge.
                def merge_ranks(list1, list2):
                    coll = list1 + list2
                    pooled = [set(subList) for subList in coll]
                    merging = True
                    while merging:
                        merging=False
                        for i,group in enumerate(pooled):
                            merged = next((g for g in pooled[i+1:] if g.intersection(group)),None)
                            if not merged: continue
                            group.update(merged)
                            pooled.remove(merged)
                            merging = True
                    return [list(x) for x in pooled]
                model_groups = merge_ranks(ranks_lookup[comm_name_1], ranks_lookup[comm_name_2])
                if verbose and world_rank == 0:
                    print(f'Creating comm groups for id {merge_name}: {model_groups}')
                for grp in model_groups:
                    tmp_group = dist.new_group(ranks = grp)
                    if world_rank in grp:
                        _COMM_LIST.append(tmp_group)
                        _COMM_NAMES[merge_name] = comm_count
                        _COMM_NAMES_META.append(merge_name)
                        comm_count += 1
            return comm_count
        # # no spatial for now: merge spatial
        # comm_count = merge_comms(comm_count, ranks_lookup, "h", "w", "spatial")
        # merge matmul
        comm_count = merge_comms(comm_count, ranks_lookup, "row_matmul", "col_matmul", "matmul")
        # now the data and model comm:
        # "model" groups are consecutive blocks of model_parallel_size ranks
        model_groups = np.reshape(np.arange(0, world_size), (-1, model_parallel_size)).tolist()
        for grp in model_groups:
            if len(grp) > 1:
                tmp_group = dist.new_group(ranks = grp)
                if world_rank in grp:
                    _COMM_LIST.append(tmp_group)
                    _COMM_NAMES["model"] = comm_count
                    _COMM_NAMES_META.append("model")
                    comm_count += 1
        if data_parallel_size == world_size:
            # no model parallelism: the data comm is every rank, stored as None
            # (the default group) instead of creating a new group
            if verbose and world_rank == 0:
                print(f"Creating comm groups for id data: {[list(range(0, world_size))]}")
            # NOTE(review): "data" is not added to _COMM_NAMES_META on this path,
            # unlike the branch below; confirm whether that asymmetry is intended.
            _COMM_LIST.append(None)
            _COMM_NAMES["data"] = comm_count
        else:
            # transpose the model blocks: ranks at the same offset in each block
            data_groups = [sorted(list(i)) for i in zip(*model_groups)]
            if verbose and world_rank == 0:
                print(f"Creating comm groups for id data: {data_groups}")
            for grp in data_groups:
                tmp_group = dist.new_group(ranks = grp)
                if world_rank in grp:
                    _COMM_LIST.append(tmp_group)
                    _COMM_NAMES["data"] = comm_count
                    _COMM_NAMES_META.append("data")
    # if verbose and world_rank == 0:
    #     print(f"comm lists are: {_COMM_LIST}")
    #     print(f"comm names are: {_COMM_NAMES}")
    return model_parallel_size