From 62b70f32fd643d9e443dc92a0e178e153faa1e0d Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Wed, 28 Aug 2024 08:34:53 -0700 Subject: [PATCH] Microbatch Device Movement (#3567) --- .../selective_backprop/selective_backprop.py | 2 + .../seq_length_warmup/seq_length_warmup.py | 2 + composer/core/data_spec.py | 2 +- composer/trainer/trainer.py | 4 +- tests/trainer/test_trainer.py | 43 +++++++++++++++++++ 5 files changed, 50 insertions(+), 3 deletions(-) diff --git a/composer/algorithms/selective_backprop/selective_backprop.py b/composer/algorithms/selective_backprop/selective_backprop.py index 88c89dfbb9..27d871cb7b 100644 --- a/composer/algorithms/selective_backprop/selective_backprop.py +++ b/composer/algorithms/selective_backprop/selective_backprop.py @@ -272,6 +272,8 @@ def apply(self, event: Event, state: State, logger: Optional[Logger] = None) -> raise RuntimeError('Model must be of type ComposerModel') self._loss_fn = state.model.loss return + + state.batch = state.device.batch_to_device(state.batch) input, target = state.batch_get_item(key=self.input_key), state.batch_get_item(key=self.target_key) assert isinstance(input, torch.Tensor) and isinstance(target, torch.Tensor), \ 'Multiple tensors not supported for this method yet.' diff --git a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py index 869ed29ae8..2ab0eecee0 100644 --- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py +++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py @@ -292,6 +292,8 @@ def _activate_model(self, state: State, logger: Logger) -> None: while True: model_inputs = {k: v[:state.device_train_microbatch_size] for k, v in batch_clone.items()} + model_inputs = state.device.batch_to_device(model_inputs) + found_cuda_oom = 0 # int since bool BOR not supported on all torch.distributed backends try: # Start by running a forward and backward pass diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py index 73f111322d..35aa94f05e 100644 --- a/composer/core/data_spec.py +++ b/composer/core/data_spec.py @@ -258,7 +258,7 @@ def _default_get_num_samples_in_batch(self, batch: Batch) -> int: '`get_num_samples_in_batch(your_batch) -> int` method.', ) dim0_sizes.append(t.shape[0]) - elif isinstance(batch, dict): + elif isinstance(batch, Mapping): for t in batch.values(): if isinstance(t, torch.Tensor): dim0_sizes.append(t.shape[0]) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 68d543e40e..815aa50001 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2659,7 +2659,6 @@ def _train_loop(self) -> None: self._rng_state = None continue - self.state.batch = self.state.device.batch_to_device(self.state.batch) self.state.batch = self._train_data_spec.device_transforms(self.state.batch) rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch) rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch) @@ -3071,6 +3070,7 @@ def _train_microbatches( current_batch = self.state.batch for microbatch_idx, self.state.batch in enumerate(microbatches): + self.state.batch = self.state.device.batch_to_device(self.state.batch) is_final_microbatch = microbatch_idx + 1 == len(microbatches) microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch) @@ -3619,7 +3619,6 @@ def _eval_loop( ) for self.state.batch in self._iter_dataloader(TrainerMode.EVAL): - self.state.batch = self.state.device.batch_to_device(self.state.batch) self.state.batch = data_spec.device_transforms(self.state.batch) # Count the batch size and num tokens before any events run @@ -3649,6 +3648,7 @@ def _eval_loop( try: microbatches = data_spec.split_batch(device_batch, evaluator.device_eval_microbatch_size) for i, self.state.batch in enumerate(microbatches): + self.state.batch = self.state.device.batch_to_device(self.state.batch) last_microbatch = i == len(microbatches) - 1 skip_metric_update = False # Distributed samplers pad batches to be the same size. If using a diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 68d3f44584..cf7abecd43 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -207,6 +207,49 @@ def test_eval_metrics(self): assert len(eval_metric_names) == 1 assert next(iter(eval_metric_names)) == single_metric + @pytest.mark.gpu + def test_memory_after_dataloader(self, model: ComposerModel): + + def track_memory_after_dataloader(global_batch_size): + + class MiniMemoryMonitor(Callback): + + def __init__(self): + self.batch_memory_usages = [] + + def epoch_start(self, state: State, logger: Logger) -> None: + current_alloc_memory = torch.cuda.memory_allocated() // 2**20 # Convert to MiB + self.batch_memory_usages.append(current_alloc_memory) + + def after_dataloader(self, state: State, logger: Logger): + current_alloc_memory = torch.cuda.memory_allocated() // 2**20 # Convert to MiB + self.batch_memory_usages.append(current_alloc_memory) + + microbatch_size = 1 + input_shape = (100000,) + dataset = RandomClassificationDataset(shape=input_shape, size=1024) + train_dataloader = DataLoader(dataset, batch_size=global_batch_size) + mini_memory_monitor = MiniMemoryMonitor() + + trainer = Trainer( + model=model, + train_dataloader=train_dataloader, + max_duration='1ba', + device='gpu', + device_train_microbatch_size=microbatch_size, + callbacks=[mini_memory_monitor], + ) + + trainer.fit() + return mini_memory_monitor.batch_memory_usages[1] - mini_memory_monitor.batch_memory_usages[0] + + global_batch_size = 32 + mem_change_epoch_start_and_after_dataloader = track_memory_after_dataloader(global_batch_size) + assert (mem_change_epoch_start_and_after_dataloader < 1), ( + f'Memory increased between epoch start and after dataloader by more than 1 MiB: {mem_change_epoch_start_and_after_dataloader} MiB. ' + f'None of the samples should be moved onto a GPU until the batch has already been divided into microbatches.' + ) + def _assert_optimizer_is_on_device(optimizer: torch.optim.Optimizer): for state in optimizer.state.values():