address comments and add device argument
H-Huang committed Nov 30, 2023
1 parent 7f3ae62 commit 653f56a
Showing 2 changed files with 46 additions and 48 deletions.
65 changes: 30 additions & 35 deletions pippy/PipelineSchedule.py
@@ -1,20 +1,19 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import argparse
import logging
import os

from collections import deque
from datetime import timedelta
from typing import List, Optional, Union
from typing import List

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity, record_function

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)


class PipelineStage(nn.Module):
@@ -26,6 +25,7 @@ def __init__(
rank: int,
world_size: int,
meta_input: torch.Tensor,
device: torch.device,
):
super().__init__()
self.rank = rank
@@ -34,15 +34,12 @@ def __init__(
self.is_last_stage = stage_id == num_stages - 1
self.num_stages = num_stages
# When we materialize the model partition on cuda, we call reset_parameters() if it is available
self.module = module.to(device=torch.cuda.current_device())
if hasattr(self.module, "reset_parameters"):
with torch.no_grad():
self.module.reset_parameters()
self.module = module.to(device)

meta_output = self.module(meta_input)
self.fwd_input = torch.empty_like(meta_input, device="cuda")
self.fwd_input = torch.empty_like(meta_input, device=device)
self.fwd_output = None
self.fwd_output_grads = torch.empty_like(meta_output, device="cuda")
self.fwd_output_grads = torch.empty_like(meta_output, device=device)
self.fwd_outputs_for_backward = deque()

self.prev_stage = (rank - 1) % world_size
@@ -53,7 +50,8 @@ def __init__(

self.requests: List[dist.P2POp] = []
logger.info(
f"finished pipeline stage init, {self.stage_id=}, {self.is_first_stage=}, {self.is_last_stage=}, {self.num_stages=}, {self.fwd_input.shape=}, {self.fwd_output_grads.shape=}"
f"finished pipeline stage init, {self.stage_id=}, {self.is_first_stage=}, \
{self.is_last_stage=}, {self.num_stages=}, {self.fwd_input.shape=}, {self.fwd_output_grads.shape=}"
)

def init_p2p_neighbors(self):
@@ -109,9 +107,7 @@ def forward(self, input_data, is_first_mb, is_last_mb):
)

# we store a ref to the input/output pair for this forward to be later used by the corresponding backward
self.fwd_outputs_for_backward.append(
(self.fwd_input, output_for_backward)
)
self.fwd_outputs_for_backward.append((self.fwd_input, output_for_backward))

return self.fwd_output

@@ -121,7 +117,7 @@ def get_bwd_send_ops(self) -> List[dist.P2POp]:
assert self.fwd_input.grad is not None, "grad must be valid"
return [dist.P2POp(dist.isend, self.fwd_input.grad, self.prev_stage)]

def get_bwd_recv_ops(self) -> Optional[dist.P2POp]:
def get_bwd_recv_ops(self) -> List[dist.P2POp]:
if self.is_last_stage:
return []
return [dist.P2POp(dist.irecv, self.fwd_output_grads, self.next_stage)]
@@ -167,7 +163,7 @@ def compute_loss(self):


class PipelineScheduleGPipe:
def __init__(self, stage):
def __init__(self, stage: PipelineStage):
self._stage = stage

def step(self, microbatches):
@@ -179,9 +175,7 @@ def step(self, microbatches):
if ops:
dist.batch_isend_irecv(ops).pop().wait()

self._stage.forward(
mb, is_first_mb=i == 0, is_last_mb=is_last_mb
)
self._stage.forward(mb, is_first_mb=i == 0, is_last_mb=is_last_mb)

ops = self._stage.get_fwd_send_ops()
if ops:
@@ -210,7 +204,7 @@ def step(self, microbatches):


class PipelineScheduleLoopedBFS:
def __init__(self, stages):
def __init__(self, stages: List[PipelineStage]):
self._stages = stages

def step(self, microbatches):
@@ -247,27 +241,27 @@ def step(self, microbatches):


class PipelineScheduleLoopedDFS:
def __init__(self, stages, n_microbatch, pp_id, n_pp):
def __init__(self, stages: List[PipelineStage], n_microbatch, pp_id, n_pp):
assert (
n_microbatch % n_pp == 0
), f"Looped DFS schedule requires microbatch_size ({n_microbatch}) to be a multiple of n_pp ({n_pp})"

self.stages = stages
self.n_microbatch = n_microbatch

self.n_stages = len(stages)
self.total_stages = self.n_stages * n_pp
self.n_local_stages = len(stages)
self.total_stages = self.n_local_stages * n_pp
# world_size
self.n_pp = n_pp

self.stage_id_to_global_stage_id = [
(i * n_pp) + pp_id for i in range(self.n_stages)
(i * n_pp) + pp_id for i in range(self.n_local_stages)
]

# pp_id is the same as local rank within the PP dimension
self.pp_id = pp_id

# number of sequences (chunks) to divide microbatches into == microbatch_size / (microbatch_size / n_pp)
# number of sequences (chunks)
self.seq_size = n_pp

# warmup steps for latest pp stage is trivial to compute
@@ -311,13 +305,13 @@ def minibatch_index(step):
# step: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
# index:0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7
return (step % self.seq_size) + self.seq_size * int(
step / (self.seq_size * self.n_stages)
step / (self.seq_size * self.n_local_stages)
)

def stage_index(step):
# step: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
# index:0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1
return int((step / self.seq_size) % self.n_stages)
return int((step / self.seq_size) % self.n_local_stages)

"""
@@ -403,7 +397,8 @@ def stage_index(step):

if requests:
logger.info(
f"rank: {self.pp_id}, current stage_id {self.stage_id_to_global_stage_id[fwd_stage_id]}, next stage_id {self.stage_id_to_global_stage_id[fwd_stage_id]} requests - {[(req.op, req.peer) for req in requests]}"
f"rank: {self.pp_id}, current stage_id {self.stage_id_to_global_stage_id[fwd_stage_id]}, \
next stage_id {self.stage_id_to_global_stage_id[fwd_stage_id]} requests - {[(req.op, req.peer) for req in requests]}"
)
forward_batched_op_handles.append(
dist.batch_isend_irecv(requests).pop()
@@ -426,11 +421,10 @@ def stage_index(step):
)
backward_batched_op_handles.pop().wait()

with record_function(
f"Stage {backward_stage.stage_id} Backward"
):
with record_function(f"Stage {backward_stage.stage_id} Backward"):
logger.info(
f"pp_id {self.pp_id} step {step}/{self.total_steps} backward_step {backward_step} backward_stage_id {backward_stage.stage_id} mb_id {mb_id_bwd}"
f"pp_id {self.pp_id} step {step}/{self.total_steps} backward_step {backward_step} \
backward_stage_id {backward_stage.stage_id} mb_id {mb_id_bwd}"
)
backward_stage.backward(
is_first_mb=mb_id_bwd == 0,
Expand All @@ -451,7 +445,8 @@ def stage_index(step):

if requests:
logger.info(
f"rank: {self.pp_id}, current stage_id {self.stage_id_to_global_stage_id[bwd_stage_id]}, next stage_id {self.stage_id_to_global_stage_id[bwd_stage_id_next]} requests - {[(req.op, req.peer) for req in requests]}"
f"rank: {self.pp_id}, current stage_id {self.stage_id_to_global_stage_id[bwd_stage_id]}, \
next stage_id {self.stage_id_to_global_stage_id[bwd_stage_id_next]} requests - {[(req.op, req.peer) for req in requests]}"
)
backward_batched_op_handles.append(
dist.batch_isend_irecv(requests).pop()
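
Note on the Looped DFS index math above: the step-to-microbatch and step-to-stage mappings can be checked in isolation. A minimal sketch, assuming seq_size = n_pp = 2 and n_local_stages = 2 (the values implied by the index tables in the code comments); the standalone helpers below are illustrative only and are not part of the commit.

# Hypothetical standalone check of the Looped DFS index math (not from the commit).
def minibatch_index(step, seq_size=2, n_local_stages=2):
    # microbatch id consumed at this step
    return (step % seq_size) + seq_size * int(step / (seq_size * n_local_stages))

def stage_index(step, seq_size=2, n_local_stages=2):
    # local stage id used at this step
    return int((step / seq_size) % n_local_stages)

# Reproduces the tables in the comments:
# minibatch: [0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7]
# stage:     [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
print([minibatch_index(s) for s in range(16)])
print([stage_index(s) for s in range(16)])
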
29 changes: 16 additions & 13 deletions test/test_pipeline_schedule.py
@@ -22,9 +22,7 @@
import logging
import os

from collections import deque
from datetime import timedelta
from typing import List, Optional, Union

import torch
import torch.distributed as dist
@@ -36,7 +34,6 @@
PipelineScheduleLoopedDFS,
PipelineStage,
)
from torch.profiler import profile, ProfilerActivity, record_function

logger = logging.getLogger(__name__)

@@ -73,7 +70,7 @@ def setup(local_rank, world_size):

# initialize the process group
logger.info(f"init for rank {local_rank}")
dist.init_process_group("nccl")
dist.init_process_group("nccl", timeout=timedelta(seconds=20))
if torch.distributed.is_initialized():
torch.cuda.set_device(local_rank)

@@ -86,10 +83,11 @@ def main(**kwargs):
rank = kwargs["rank"]
local_rank = kwargs["local_rank"]
world_size = kwargs["world_size"]
device = torch.device(kwargs["device"])

setup(local_rank, world_size)
logger.info(
f"====== World Rank {rank}, Local Rank {local_rank}, World Size {world_size} main ======"
f"====== World Rank {rank}, Local Rank {local_rank}, World Size {world_size}, Device {device} main ======"
)

input_dim = 4000
Expand All @@ -110,7 +108,7 @@ def main(**kwargs):
x = torch.randn([microbatch_size, input_dim]).to("meta")

stage_model = PipelineStage(
module_list[rank], rank, world_size, rank, world_size, x
module_list[rank], rank, world_size, rank, world_size, x, device
)
stage_model.init_p2p_neighbors()

@@ -122,6 +120,7 @@
rank=rank,
world_size=world_size,
meta_input=x,
device=device,
)
for i in range(world_size)
]
@@ -177,13 +176,13 @@ def set_up_logging(rank, log_level=logging.INFO):
handler = logging.StreamHandler()
handler.setLevel(log_level)

class FstringFormatter(logging.Formatter):
def format(self, record):
return f"[{rank}][{record.levelname}][{self.formatTime(record)}][{os.path.basename(__file__)}:{record.lineno}]: {record.getMessage()}"
# TODO: seeing double logging due to global logging setup in
# - fx/passes/utils/matcher_utils.py

formatter = FstringFormatter()
handler.setFormatter(formatter)
logger.addHandler(handler)
# class FstringFormatter(logging.Formatter):
# def format(self, record):
# return f"[{rank}][{record.levelname}][{self.formatTime(record)}][{os.path.basename(__file__)}:{record.lineno}]:{record.getMessage()}"

# formatter = FstringFormatter()
# handler.setFormatter(formatter)
# logger.addHandler(handler)


if __name__ == "__main__":
Expand All @@ -201,8 +203,9 @@ def format(self, record):
type=str,
nargs="+",
choices=["gpipe", "looped_bfs", "looped_dfs"],
default=["gpipe", "looped_bfs"],
default=["gpipe", "looped_bfs", "looped_dfs"],
)
parser.add_argument("--device", type=str, default="cuda")
args = parser.parse_args()
kwargs = vars(args)
print(kwargs)
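
The setup() change above adds an explicit timeout to process-group initialization. A minimal sketch of the same pattern that runs standalone, assuming a single-process "gloo" group instead of the "nccl" group the test uses (so no GPU or multi-process launcher is required):

# Illustrative only: single-process "gloo" group with the same timeout pattern
# used in setup(); the test itself calls dist.init_process_group("nccl", ...)
# once per local rank.
import os
from datetime import timedelta

import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

dist.init_process_group("gloo", rank=0, world_size=1, timeout=timedelta(seconds=20))
assert dist.is_initialized()
dist.destroy_process_group()
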
