Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring old seed technique back #2319

Merged
merged 5 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ class Accelerator:
If set to `True`, in cases where the total batch size across all processes does not exactly divide the
dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
all workers.
use_seedable_sampler (`bool`, *optional*, defaults to `False`):
Whether or not use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Ensures
training results are fully reproducable using a different sampling technique. While seed-to-seed results
may differ, on average the differences are neglible when using multiple different seeds to compare. Should
also be ran with [`~utils.set_seed`] for the best results.
step_scheduler_with_optimizer (`bool`, *optional`, defaults to `True`):
Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
done under certain circumstances (at the end of each epoch, for instance).
Expand Down Expand Up @@ -262,6 +267,7 @@ def __init__(
gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
dispatch_batches: bool | None = None,
even_batches: bool = True,
use_seedable_sampler: bool = False,
step_scheduler_with_optimizer: bool = True,
kwargs_handlers: list[KwargsHandler] | None = None,
dynamo_backend: DynamoBackend | str | None = None,
Expand Down Expand Up @@ -417,6 +423,7 @@ def __init__(
self.split_batches = split_batches
self.dispatch_batches = dispatch_batches
self.even_batches = even_batches
self.use_seedable_sampler = use_seedable_sampler
self.step_scheduler_with_optimizer = step_scheduler_with_optimizer

# Mixed precision attributes
Expand Down Expand Up @@ -1857,6 +1864,7 @@ def prepare_data_loader(
dispatch_batches=self.dispatch_batches,
even_batches=self.even_batches,
slice_fn_for_dispatch=slice_fn_for_dispatch,
use_seedable_sampler=self.use_seedable_sampler,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
Expand Down
9 changes: 7 additions & 2 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ def prepare_data_loader(
dispatch_batches: Optional[bool] = None,
even_batches: bool = True,
slice_fn_for_dispatch: Optional[Callable] = None,
use_seedable_sampler: bool = False,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
Expand Down Expand Up @@ -797,6 +798,10 @@ def prepare_data_loader(
If passed, this function will be used to slice tensors across `num_processes`. Will default to
[`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
ignored otherwise.
use_seedable_sampler (`bool`, *optional*, defaults to `False`):
Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
reproducability. Comes at a cost of potentially different performances due to different shuffling
algorithms but ensures results will be the *exact* same.

Returns:
`torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
Expand Down Expand Up @@ -840,7 +845,7 @@ def prepare_data_loader(
sampler = getattr(dataloader.sampler, "sampler", None)
else:
sampler = getattr(dataloader.batch_sampler, "sampler", None)
if isinstance(sampler, RandomSampler):
if isinstance(sampler, RandomSampler) and use_seedable_sampler:
# When iterating through the dataloader during distributed processes
# we want to ensure that on each process we are iterating through the same
# samples in the same order if a seed is set. This requires a tweak
Expand Down Expand Up @@ -899,7 +904,7 @@ def prepare_data_loader(
kwargs["batch_size"] = (
dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
)
if isinstance(sampler, SeedableRandomSampler):
if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
if sampler_is_batch_sampler:
dataloader.sampler.sampler = sampler
else:
Expand Down
50 changes: 32 additions & 18 deletions src/accelerate/test_utils/scripts/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,19 +335,22 @@ def __len__(self):
), "Custom sampler was changed after calling `prepare_data_loader`"


def mock_training(length, batch_size, generator):
def mock_training(length, batch_size, generator, use_seedable_sampler=False):
set_seed(42)
generator.manual_seed(42)
train_set = RegressionDataset(length=length, seed=42)

# The SeedableRandomSampler is needed during distributed setups
# for full reproducability across processes with the `DataLoader`
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
if use_seedable_sampler:
# The SeedableRandomSampler is needed during distributed setups
# for full reproducability across processes with the `DataLoader`
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
else:
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(3):
Expand All @@ -360,18 +363,28 @@ def mock_training(length, batch_size, generator):
return train_set, model


def training_check():
def training_check(use_seedable_sampler=False):
state = AcceleratorState()
generator = torch.Generator()
batch_size = 8
length = batch_size * 4 * state.num_processes

train_set, old_model = mock_training(length, batch_size * state.num_processes, generator)
train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, use_seedable_sampler)
assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."

accelerator = Accelerator()
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
if use_seedable_sampler:
# The SeedableRandomSampler is needed during distributed setups
# for full reproducability across processes with the `DataLoader`
sampler = SeedableRandomSampler(
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
else:
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

Expand All @@ -392,7 +405,7 @@ def training_check():

accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")

accelerator = Accelerator(split_batches=True)
accelerator = Accelerator(split_batches=True, use_seedable_sampler=use_seedable_sampler)
train_dl = DataLoader(train_set, batch_size=batch_size * state.num_processes, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
Expand All @@ -418,7 +431,7 @@ def training_check():
# Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
print("FP16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="fp16")
accelerator = Accelerator(mixed_precision="fp16", use_seedable_sampler=use_seedable_sampler)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
Expand Down Expand Up @@ -458,7 +471,7 @@ def training_check():
# Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
print("BF16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="bf16")
accelerator = Accelerator(mixed_precision="bf16", use_seedable_sampler=use_seedable_sampler)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
Expand All @@ -482,7 +495,7 @@ def training_check():
if is_ipex_available():
print("ipex BF16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="bf16", cpu=True)
accelerator = Accelerator(mixed_precision="bf16", cpu=True, use_seedable_sampler=use_seedable_sampler)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
Expand All @@ -506,7 +519,7 @@ def training_check():
if is_xpu_available():
print("xpu BF16 training check.")
AcceleratorState._reset_state()
accelerator = Accelerator(mixed_precision="bf16", cpu=False)
accelerator = Accelerator(mixed_precision="bf16", cpu=False, use_seedable_sampler=use_seedable_sampler)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
Expand Down Expand Up @@ -661,7 +674,8 @@ def main():

if state.local_process_index == 0:
print("\n**Training integration test**")
training_check()
training_check(use_seedable_sampler=False)
training_check(use_seedable_sampler=True)

if state.local_process_index == 0:
print("\n**Breakpoint trigger test**")
Expand Down
Loading