Skip to content

Commit

Permalink
Expose _debug_mask_minibatches for stable numerical testing
Browse files Browse the repository at this point in the history
  • Loading branch information
kwen2501 committed Nov 15, 2023
1 parent 15dfcd8 commit 4b65796
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
16 changes: 10 additions & 6 deletions pippy/microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@
from pippy.IR import TrivialLossWrapper


"""
_debug_mask_minibatches specifies to send masked versions of the mini-batch
through instead of micro-batch slices--this can be used for more stable
numerical testing (see [A Note About Correctness Testing])
"""
_debug_mask_minibatches = False


class CustomReducer:
def __init__(self, init_value, reduce_fn):
self.init_value = init_value
Expand Down Expand Up @@ -48,7 +56,6 @@ def shard_dict_of_args(
args_dict,
args_chunk_spec,
num_chunks,
_debug_mask_minibatches: bool = False,
):
# Stage 1+2: flatten and shard/replicate

Expand Down Expand Up @@ -173,7 +180,6 @@ def split_args_kwargs_into_chunks(
chunks,
args_chunk_spec=None,
kwargs_chunk_spec=None,
_debug_mask_minibatches: bool = False,
):
# Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
# the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
Expand Down Expand Up @@ -221,12 +227,11 @@ def split_args_kwargs_into_chunks(
dict(enumerate(args)),
dict(enumerate(args_chunk_spec)),
chunks,
_debug_mask_minibatches,
)
real_num_chunks = len(args_split_dict)

kwargs_split = shard_dict_of_args(
kwargs, kwargs_chunk_spec, real_num_chunks, _debug_mask_minibatches
kwargs, kwargs_chunk_spec, real_num_chunks,
)

if len(kwargs_split) < real_num_chunks:
Expand All @@ -238,7 +243,6 @@ def split_args_kwargs_into_chunks(
dict(enumerate(args)),
dict(enumerate(args_chunk_spec)),
real_num_chunks,
_debug_mask_minibatches,
)

if len(args_split_dict) != len(kwargs_split):
Expand All @@ -254,7 +258,7 @@ def split_args_kwargs_into_chunks(
return args_split, kwargs_split


def merge_chunks(chunks, chunk_spec, _debug_mask_minibatches: bool = False):
def merge_chunks(chunks, chunk_spec):
# Given a list of chunks and a chunk specification, merge the chunks
# into a single value according to that chunk spec. This is essentially
# the inverse of `split_args_kwargs_into_chunks`, so the steps are
Expand Down
5 changes: 5 additions & 0 deletions test/local_test_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
import torch
import torch.distributed as dist

import pippy
from pippy.compile import compile_stage
from pippy.IR import pipe_split


# For stable numerical testing
pippy.microbatch._debug_mask_minibatches = True


d_hid = 512
chunk_size = 256

Expand Down

0 comments on commit 4b65796

Please sign in to comment.