From f8375627ffe01a2b8f28a349a01697641809d91e Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Tue, 9 Jan 2024 04:14:47 -0500
Subject: [PATCH 1/5] Redo stage 1

---
 src/accelerate/accelerator.py             |  7 ++++++
 src/accelerate/data_loader.py             |  9 +++++--
 .../test_utils/scripts/test_script.py     | 25 ++++++++++++-------
 3 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 32e51f25983..a4cddd10213 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -217,6 +217,10 @@ class Accelerator:
             If set to `True`, in cases where the total batch size across all processes does not exactly divide the
             dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
             all workers.
+        use_seedable_sampler (`bool`, *optional*, defaults to `False`):
+            Whether or not to use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Comes at a
+            cost of potentially different performance due to different shuffling algorithms, but will ensure the
+            training results are fully reproducible.
         step_scheduler_with_optimizer (`bool`, *optional*, defaults to `True`):
             Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
             done under certain circumstances (at the end of each epoch, for instance).
@@ -262,6 +266,7 @@ def __init__(
         gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
         dispatch_batches: bool | None = None,
         even_batches: bool = True,
+        use_seedable_sampler: bool = False,
         step_scheduler_with_optimizer: bool = True,
         kwargs_handlers: list[KwargsHandler] | None = None,
         dynamo_backend: DynamoBackend | str | None = None,
@@ -417,6 +422,7 @@ def __init__(
         self.split_batches = split_batches
         self.dispatch_batches = dispatch_batches
         self.even_batches = even_batches
+        self.use_seedable_sampler = use_seedable_sampler
         self.step_scheduler_with_optimizer = step_scheduler_with_optimizer
 
         # Mixed precision attributes
@@ -1857,6 +1863,7 @@ def prepare_data_loader(
             dispatch_batches=self.dispatch_batches,
             even_batches=self.even_batches,
             slice_fn_for_dispatch=slice_fn_for_dispatch,
+            use_seedable_sampler=self.use_seedable_sampler,
         )
         self._dataloaders.append(prepared_data_loader)
         return prepared_data_loader
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 01a348f82bd..1456454eee3 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -744,6 +744,7 @@ def prepare_data_loader(
     dispatch_batches: Optional[bool] = None,
     even_batches: bool = True,
     slice_fn_for_dispatch: Optional[Callable] = None,
+    use_seedable_sampler: bool = False,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -797,6 +798,10 @@ def prepare_data_loader(
             If passed, this function will be used to slice tensors across `num_processes`. Will default to
             [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
             ignored otherwise.
+        use_seedable_sampler (`bool`, *optional*, defaults to `False`):
+            Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
+            reproducibility. Comes at a cost of potentially different performance due to different shuffling
+            algorithms, but ensures results will be the *exact* same.
     Returns:
         `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
@@ -840,7 +845,7 @@ def prepare_data_loader(
         sampler = getattr(dataloader.sampler, "sampler", None)
     else:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
-    if isinstance(sampler, RandomSampler):
+    if isinstance(sampler, RandomSampler) and use_seedable_sampler:
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
         # samples in the same order if a seed is set. This requires a tweak
@@ -899,7 +904,7 @@ def prepare_data_loader(
         kwargs["batch_size"] = (
             dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
         )
-    if isinstance(sampler, SeedableRandomSampler):
+    if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
         if sampler_is_batch_sampler:
             dataloader.sampler.sampler = sampler
         else:
diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py
index 14269cb3b69..08428562f66 100644
--- a/src/accelerate/test_utils/scripts/test_script.py
+++ b/src/accelerate/test_utils/scripts/test_script.py
@@ -335,19 +335,22 @@ def __len__(self):
     ), "Custom sampler was changed after calling `prepare_data_loader`"
 
 
-def mock_training(length, batch_size, generator):
+def mock_training(length, batch_size, generator, use_seedable_sampler=False):
     set_seed(42)
     generator.manual_seed(42)
     train_set = RegressionDataset(length=length, seed=42)
-    # The SeedableRandomSampler is needed during distributed setups
-    # for full reproducibility across processes with the `DataLoader`
-    sampler = SeedableRandomSampler(
-        generator=generator,
-        data_source=train_set,
-        num_samples=len(train_set),
-    )
-    train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
+    if use_seedable_sampler:
+        # The SeedableRandomSampler is needed during distributed setups
+        # for full reproducibility across processes with the `DataLoader`
+        sampler = SeedableRandomSampler(
+            generator=generator,
+            data_source=train_set,
+            num_samples=len(train_set),
+        )
+        train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
+    else:
+        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
     model = RegressionModel()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     for epoch in range(3):
@@ -370,6 +373,10 @@ def training_check():
     assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
     assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."
 
+    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, True)
+    assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
+    assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."
+
     accelerator = Accelerator()
     train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
     model = RegressionModel()

From 76e71f20018111827d9dcfa2841ade2001f506c7 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Tue, 9 Jan 2024 04:24:27 -0500
Subject: [PATCH 2/5] Fix rest of tests

---
 .../test_utils/scripts/test_script.py | 33 +++++++++++--------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py
index 08428562f66..7dfb6ec7530 100644
--- a/src/accelerate/test_utils/scripts/test_script.py
+++ b/src/accelerate/test_utils/scripts/test_script.py
@@ -363,22 +363,28 @@ def mock_training(length, batch_size, generator, use_seedable_sampler=False):
     return train_set, model
 
 
-def training_check():
+def training_check(use_seedable_sampler=False):
     state = AcceleratorState()
     generator = torch.Generator()
     batch_size = 8
     length = batch_size * 4 * state.num_processes
 
-    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator)
-    assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
-    assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."
-
-    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, True)
+    train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, use_seedable_sampler)
     assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
     assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."
 
     accelerator = Accelerator()
-    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
+    if use_seedable_sampler:
+        # The SeedableRandomSampler is needed during distributed setups
+        # for full reproducibility across processes with the `DataLoader`
+        sampler = SeedableRandomSampler(
+            generator=generator,
+            data_source=train_set,
+            num_samples=len(train_set),
+        )
+        train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
+    else:
+        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
     model = RegressionModel()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
@@ -399,7 +405,7 @@ def training_check():
 
     accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")
 
-    accelerator = Accelerator(split_batches=True)
+    accelerator = Accelerator(split_batches=True, use_seedable_sampler=use_seedable_sampler)
     train_dl = DataLoader(train_set, batch_size=batch_size * state.num_processes, shuffle=True, generator=generator)
     model = RegressionModel()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
@@ -425,7 +431,7 @@ def training_check():
     # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
     print("FP16 training check.")
     AcceleratorState._reset_state()
-    accelerator = Accelerator(mixed_precision="fp16")
+    accelerator = Accelerator(mixed_precision="fp16", use_seedable_sampler=use_seedable_sampler)
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
     model = RegressionModel()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
@@ -465,7 +471,7 @@ def training_check():
     # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
     print("BF16 training check.")
     AcceleratorState._reset_state()
-    accelerator = Accelerator(mixed_precision="bf16")
+    accelerator = Accelerator(mixed_precision="bf16", use_seedable_sampler=use_seedable_sampler)
     train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
     model = RegressionModel()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
@@ -489,7 +495,7 @@ def training_check():
     if is_ipex_available():
         print("ipex BF16 training check.")
         AcceleratorState._reset_state()
-        accelerator = Accelerator(mixed_precision="bf16", cpu=True)
+        accelerator = Accelerator(mixed_precision="bf16", cpu=True, use_seedable_sampler=use_seedable_sampler)
         train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
         model = RegressionModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
@@ -513,7 +519,7 @@ def training_check():
     if is_xpu_available():
         print("xpu BF16 training check.")
         AcceleratorState._reset_state()
-        accelerator = Accelerator(mixed_precision="bf16", cpu=False)
+        accelerator = Accelerator(mixed_precision="bf16", cpu=False, use_seedable_sampler=use_seedable_sampler)
         train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
         model = RegressionModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
@@ -668,7 +674,8 @@ def main():
 
     if state.local_process_index == 0:
         print("\n**Training integration test**")
-    training_check()
+    training_check(use_seedable_sampler=False)
+    training_check(use_seedable_sampler=True)
 
     if state.local_process_index == 0:
         print("\n**Breakpoint trigger test**")

From ebb649537c4548e87783679b09dc1a03cb73520d Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Tue, 9 Jan 2024 12:04:54 -0500
Subject: [PATCH 3/5] Expand doc

---
 src/accelerate/accelerator.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index a4cddd10213..0d75b225479 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -218,9 +218,10 @@ class Accelerator:
             dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
             all workers.
         use_seedable_sampler (`bool`, *optional*, defaults to `False`):
-            Whether or not to use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Comes at a
-            cost of potentially different performance due to different shuffling algorithms, but will ensure the
-            training results are fully reproducible.
+            Whether or not to use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Ensures
+            training results are fully reproducible using a different sampling technique. While seed-to-seed
+            results may differ, on average the differences are negligible when using multiple different seeds to
+            compare.
         step_scheduler_with_optimizer (`bool`, *optional*, defaults to `True`):
             Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
             done under certain circumstances (at the end of each epoch, for instance).
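The test changes above toggle between the two dataloader constructions that the rest of the series wires through `prepare_data_loader`. For reference, a minimal standalone sketch of that same toggle, assuming a throwaway `TensorDataset` in place of the test suite's `RegressionDataset`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import SeedableRandomSampler

# Placeholder dataset, sizes, and seed; the patch's tests use RegressionDataset instead.
train_set = TensorDataset(torch.randn(96, 3))
generator = torch.Generator()
generator.manual_seed(42)

# Seedable path (use_seedable_sampler=True): the shuffling order is driven entirely by
# the seeded generator, so each run (and, per the patch's comment, each process in a
# distributed setup) iterates through the samples in the same order.
sampler = SeedableRandomSampler(
    generator=generator,
    data_source=train_set,
    num_samples=len(train_set),
)
seedable_dl = DataLoader(train_set, batch_size=8, sampler=sampler)

# Default path (use_seedable_sampler=False): a plain shuffled DataLoader.
default_dl = DataLoader(train_set, batch_size=8, shuffle=True, generator=generator)
```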
From 30610fae07b176ae6058ea41e25f93d2bab7e7b5 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Tue, 9 Jan 2024 12:05:31 -0500
Subject: [PATCH 4/5] Expand x2

---
 src/accelerate/accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 0d75b225479..2dd7d9fd95e 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -221,7 +221,7 @@ class Accelerator:
             Whether or not to use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Ensures
             training results are fully reproducible using a different sampling technique. While seed-to-seed
             results may differ, on average the differences are negligible when using multiple different seeds to
-            compare.
+            compare. Should also be run with [`~utils.set_seed`] for the best results.
         step_scheduler_with_optimizer (`bool`, *optional*, defaults to `True`):
             Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
             done under certain circumstances (at the end of each epoch, for instance).

From c7ae51a59352b9b21c53b0d28bf8d9ba521f02ba Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Tue, 9 Jan 2024 12:05:38 -0500
Subject: [PATCH 5/5] Expand x2

---
 src/accelerate/accelerator.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 2dd7d9fd95e..c49b722f9fe 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -219,9 +219,9 @@ class Accelerator:
             all workers.
         use_seedable_sampler (`bool`, *optional*, defaults to `False`):
             Whether or not to use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Ensures
-            training results are fully reproducible using a different sampling technique. While seed-to-seed
-            results may differ, on average the differences are negligible when using multiple different seeds to
-            compare. Should also be run with [`~utils.set_seed`] for the best results.
+            training results are fully reproducible using a different sampling technique. While seed-to-seed results
+            may differ, on average the differences are negligible when using multiple different seeds to compare.
+            Should also be run with [`~utils.set_seed`] for the best results.
         step_scheduler_with_optimizer (`bool`, *optional*, defaults to `True`):
             Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
             done under certain circumstances (at the end of each epoch, for instance).
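Once the series is applied, enabling the behavior from user code only requires the new constructor flag. A usage sketch under the assumption that the series is merged as-is; the model, data, and hyperparameters below are placeholders, not part of the patch:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import set_seed

set_seed(42)  # recommended alongside use_seedable_sampler for fully reproducible runs

# New flag introduced by this series; it defaults to False to preserve prior behavior.
accelerator = Accelerator(use_seedable_sampler=True)

dataset = TensorDataset(torch.randn(64, 3), torch.randn(64, 1))  # placeholder data
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# prepare() routes through prepare_data_loader(), which now swaps the RandomSampler
# for a SeedableRandomSampler only when the flag is set.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, targets in dataloader:
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```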