diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py
index 1df0d08e4..aac749ae1 100644
--- a/pippy/PipelineSchedule.py
+++ b/pippy/PipelineSchedule.py
@@ -1,20 +1,19 @@
 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
-import argparse
 import logging
-import os
-
 from collections import deque
-from datetime import timedelta
-from typing import List, Optional, Union
+from typing import List
 
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.profiler import profile, ProfilerActivity, record_function
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+handler = logging.StreamHandler()
+handler.setLevel(logging.INFO)
+logger.addHandler(handler)
 
 
 class PipelineStage(nn.Module):
@@ -26,6 +25,7 @@ def __init__(
         rank: int,
         world_size: int,
         meta_input: torch.Tensor,
+        device: torch.device,
     ):
         super().__init__()
         self.rank = rank
@@ -34,15 +34,12 @@ def __init__(
         self.is_last_stage = stage_id == num_stages - 1
         self.num_stages = num_stages
         # When we materialize the model partition on cuda, we call reset_parameters() if it is available
-        self.module = module.to(device=torch.cuda.current_device())
-        if hasattr(self.module, "reset_parameters"):
-            with torch.no_grad():
-                self.module.reset_parameters()
+        self.module = module.to(device)
 
         meta_output = self.module(meta_input)
-        self.fwd_input = torch.empty_like(meta_input, device="cuda")
+        self.fwd_input = torch.empty_like(meta_input, device=device)
         self.fwd_output = None
-        self.fwd_output_grads = torch.empty_like(meta_output, device="cuda")
+        self.fwd_output_grads = torch.empty_like(meta_output, device=device)
         self.fwd_outputs_for_backward = deque()
 
         self.prev_stage = (rank - 1) % world_size
@@ -53,7 +50,8 @@ def __init__(
         self.requests: List[dist.P2POp] = []
 
         logger.info(
-            f"finished pipeline stage init, {self.stage_id=}, {self.is_first_stage=}, {self.is_last_stage=}, {self.num_stages=}, {self.fwd_input.shape=}, {self.fwd_output_grads.shape=}"
+            f"finished pipeline stage init, {self.stage_id=}, {self.is_first_stage=}, \
+            {self.is_last_stage=}, {self.num_stages=}, {self.fwd_input.shape=}, {self.fwd_output_grads.shape=}"
         )
 
     def init_p2p_neighbors(self):
@@ -109,9 +107,7 @@ def forward(self, input_data, is_first_mb, is_last_mb):
         )
 
         # we store a ref to the input/output pair for this forward to be later used by the corresponding backward
-        self.fwd_outputs_for_backward.append(
-            (self.fwd_input, output_for_backward)
-        )
+        self.fwd_outputs_for_backward.append((self.fwd_input, output_for_backward))
 
         return self.fwd_output
 
@@ -121,7 +117,7 @@ def get_bwd_send_ops(self) -> List[dist.P2POp]:
         assert self.fwd_input.grad is not None, "grad must be valid"
         return [dist.P2POp(dist.isend, self.fwd_input.grad, self.prev_stage)]
 
-    def get_bwd_recv_ops(self) -> Optional[dist.P2POp]:
+    def get_bwd_recv_ops(self) -> List[dist.P2POp]:
         if self.is_last_stage:
             return []
         return [dist.P2POp(dist.irecv, self.fwd_output_grads, self.next_stage)]
@@ -167,7 +163,7 @@ def compute_loss(self):
 
 
 class PipelineScheduleGPipe:
-    def __init__(self, stage):
+    def __init__(self, stage: PipelineStage):
         self._stage = stage
 
     def step(self, microbatches):
@@ -179,9 +175,7 @@ def step(self, microbatches):
             if ops:
                 dist.batch_isend_irecv(ops).pop().wait()
 
-            self._stage.forward(
-                mb, is_first_mb=i == 0, is_last_mb=is_last_mb
-            )
+            self._stage.forward(mb, is_first_mb=i == 0, is_last_mb=is_last_mb)
 
             ops = self._stage.get_fwd_send_ops()
             if ops:
@@ -210,7 +204,7 @@ def step(self, microbatches):
 
 
 class PipelineScheduleLoopedBFS:
-    def __init__(self, stages):
+    def __init__(self, stages: List[PipelineStage]):
         self._stages = stages
 
     def step(self, microbatches):
@@ -247,7 +241,7 @@ def step(self, microbatches):
 
 
 class PipelineScheduleLoopedDFS:
-    def __init__(self, stages, n_microbatch, pp_id, n_pp):
+    def __init__(self, stages: List[PipelineStage], n_microbatch, pp_id, n_pp):
         assert (
             n_microbatch % n_pp == 0
         ), f"Looped DFS schedule requires microbatch_size ({n_microbatch}) to be a multiple of n_pp ({n_pp})"
@@ -255,19 +249,19 @@ def __init__(self, stages, n_microbatch, pp_id, n_pp):
         self.stages = stages
         self.n_microbatch = n_microbatch
 
-        self.n_stages = len(stages)
-        self.total_stages = self.n_stages * n_pp
+        self.n_local_stages = len(stages)
+        self.total_stages = self.n_local_stages * n_pp
         # world_size
         self.n_pp = n_pp
 
         self.stage_id_to_global_stage_id = [
-            (i * n_pp) + pp_id for i in range(self.n_stages)
+            (i * n_pp) + pp_id for i in range(self.n_local_stages)
         ]
 
         # pp_id is the same as local rank within the PP dimension
        self.pp_id = pp_id
 
-        # number of sequences (chunks) to divide microbatches into == microbatch_size / (microbatch_size / n_pp)
+        # number of sequences (chunks)
         self.seq_size = n_pp
 
         # warmup steps for latest pp stage is trivial to compute
@@ -311,13 +305,13 @@ def minibatch_index(step):
             # step: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
             # index:0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7
             return (step % self.seq_size) + self.seq_size * int(
-                step / (self.seq_size * self.n_stages)
+                step / (self.seq_size * self.n_local_stages)
            )
 
         def stage_index(step):
             # step: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
             # index:0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1
-            return int((step / self.seq_size) % self.n_stages)
+            return int((step / self.seq_size) % self.n_local_stages)
 
         """
@@ -403,7 +397,8 @@ def stage_index(step):
                 if requests:
                     logger.info(
-                        f"rank: {self.pp_id}, current stage_id {self.stage_id_to_global_stage_id[fwd_stage_id]}, next stage_id {self.stage_id_to_global_stage_id[fwd_stage_id]} requests - {[(req.op, req.peer) for req in requests]}"
+                        f"rank: {self.pp_id}, current stage_id {self.stage_id_to_global_stage_id[fwd_stage_id]}, \
+                        next stage_id {self.stage_id_to_global_stage_id[fwd_stage_id]} requests - {[(req.op, req.peer) for req in requests]}"
                     )
                     forward_batched_op_handles.append(
                         dist.batch_isend_irecv(requests).pop()
                     )
@@ -426,11 +421,10 @@ def stage_index(step):
                     )
                     backward_batched_op_handles.pop().wait()
 
-                with record_function(
-                    f"Stage {backward_stage.stage_id} Backward"
-                ):
+                with record_function(f"Stage {backward_stage.stage_id} Backward"):
                     logger.info(
-                        f"pp_id {self.pp_id} step {step}/{self.total_steps} backward_step {backward_step} backward_stage_id {backward_stage.stage_id} mb_id {mb_id_bwd}"
+                        f"pp_id {self.pp_id} step {step}/{self.total_steps} backward_step {backward_step} \
+                        backward_stage_id {backward_stage.stage_id} mb_id {mb_id_bwd}"
                     )
                     backward_stage.backward(
                         is_first_mb=mb_id_bwd == 0,
@@ -451,7 +445,8 @@ def stage_index(step):
 
                 if requests:
                     logger.info(
-                        f"rank: {self.pp_id}, current stage_id {self.stage_id_to_global_stage_id[bwd_stage_id]}, next stage_id {self.stage_id_to_global_stage_id[bwd_stage_id_next]} requests - {[(req.op, req.peer) for req in requests]}"
+                        f"rank: {self.pp_id}, current stage_id {self.stage_id_to_global_stage_id[bwd_stage_id]}, \
+                        next stage_id {self.stage_id_to_global_stage_id[bwd_stage_id_next]} requests - {[(req.op, req.peer) for req in requests]}"
                     )
                     backward_batched_op_handles.append(
                         dist.batch_isend_irecv(requests).pop()
diff --git a/test/test_pipeline_schedule.py b/test/test_pipeline_schedule.py
index 561901a9a..9b18c0c07 100644
--- a/test/test_pipeline_schedule.py
+++ b/test/test_pipeline_schedule.py
@@ -22,9 +22,7 @@
 import logging
 import os
 
-from collections import deque
 from datetime import timedelta
-from typing import List, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -36,7 +34,6 @@
     PipelineScheduleLoopedDFS,
     PipelineStage,
 )
-from torch.profiler import profile, ProfilerActivity, record_function
 
 logger = logging.getLogger(__name__)
 
@@ -73,7 +70,7 @@ def setup(local_rank, world_size):
 
     # initialize the process group
     logger.info(f"init for rank {local_rank}")
-    dist.init_process_group("nccl")
+    dist.init_process_group("nccl", timeout=timedelta(seconds=20))
     if torch.distributed.is_initialized():
         torch.cuda.set_device(local_rank)
 
@@ -86,10 +83,11 @@ def main(**kwargs):
     rank = kwargs["rank"]
     local_rank = kwargs["local_rank"]
     world_size = kwargs["world_size"]
+    device = torch.device(kwargs["device"])
 
     setup(local_rank, world_size)
     logger.info(
-        f"====== World Rank {rank}, Local Rank {local_rank}, World Size {world_size} main ======"
+        f"====== World Rank {rank}, Local Rank {local_rank}, World Size {world_size}, Device {device} main ======"
     )
 
     input_dim = 4000
@@ -110,7 +108,7 @@ def main(**kwargs):
     x = torch.randn([microbatch_size, input_dim]).to("meta")
 
     stage_model = PipelineStage(
-        module_list[rank], rank, world_size, rank, world_size, x
+        module_list[rank], rank, world_size, rank, world_size, x, device
     )
     stage_model.init_p2p_neighbors()
 
@@ -122,6 +120,7 @@ def main(**kwargs):
             rank=rank,
             world_size=world_size,
             meta_input=x,
+            device=device,
         )
         for i in range(world_size)
     ]
@@ -177,13 +176,16 @@ def set_up_logging(rank, log_level=logging.INFO):
     handler = logging.StreamHandler()
     handler.setLevel(log_level)
 
-    class FstringFormatter(logging.Formatter):
-        def format(self, record):
-            return f"[{rank}][{record.levelname}][{self.formatTime(record)}][{os.path.basename(__file__)}:{record.lineno}]: {record.getMessage()}"
+    # TODO: seeing double logging due to global logging setup in
+    # - fx/passes/utils/matcher_utils.py
 
-    formatter = FstringFormatter()
-    handler.setFormatter(formatter)
-    logger.addHandler(handler)
+    # class FstringFormatter(logging.Formatter):
+    #     def format(self, record):
+    #         return f"[{rank}][{record.levelname}][{self.formatTime(record)}][{os.path.basename(__file__)}:{record.lineno}]:{record.getMessage()}"
+
+    # formatter = FstringFormatter()
+    # handler.setFormatter(formatter)
+    # logger.addHandler(handler)
 
 
 if __name__ == "__main__":
@@ -201,8 +203,9 @@ def format(self, record):
         type=str,
         nargs="+",
         choices=["gpipe", "looped_bfs", "looped_dfs"],
-        default=["gpipe", "looped_bfs"],
+        default=["gpipe", "looped_bfs", "looped_dfs"],
     )
+    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()
    kwargs = vars(args)
    print(kwargs)
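
Side note (not part of the patch): both files drive communication through the same batched point-to-point pattern, building a list of dist.P2POp ops (get_fwd_send_ops / get_fwd_recv_ops and their backward counterparts) and handing it to dist.batch_isend_irecv, then waiting on the returned requests. The snippet below is a minimal standalone sketch of only that pattern. It assumes exactly two ranks launched via torchrun --nproc_per_node=2 and uses the gloo backend so it runs without GPUs, unlike the NCCL/CUDA setup in the test above.

```python
# Minimal sketch of the P2POp + batch_isend_irecv pattern used by the schedules.
# Assumptions: two ranks launched with `torchrun --nproc_per_node=2 <this_file>.py`;
# gloo backend (CPU-only) instead of the NCCL/CUDA setup used in the test.
import torch
import torch.distributed as dist

dist.init_process_group("gloo")  # torchrun supplies rank/world size via env vars
rank = dist.get_rank()
peer = 1 - rank  # the only other rank in a 2-rank job

send_buf = torch.full((4,), float(rank))  # payload identifying the sender
recv_buf = torch.empty(4)                 # pre-allocated receive buffer

# Build the op list (the schedules build these via get_*_send_ops / get_*_recv_ops),
# then post the whole batch at once and wait on the returned requests.
ops = [
    dist.P2POp(dist.isend, send_buf, peer),
    dist.P2POp(dist.irecv, recv_buf, peer),
]
for req in dist.batch_isend_irecv(ops):
    req.wait()

print(f"rank {rank} received {recv_buf.tolist()} from rank {peer}")
dist.destroy_process_group()
```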