Microbatch Device Movement (mosaicml#3567)
mvpatel2000 authored Aug 28, 2024
1 parent 05e8c20 commit 62b70f3
Showing 5 changed files with 50 additions and 3 deletions.
2 changes: 2 additions & 0 deletions composer/algorithms/selective_backprop/selective_backprop.py
@@ -272,6 +272,8 @@ def apply(self, event: Event, state: State, logger: Optional[Logger] = None) ->
raise RuntimeError('Model must be of type ComposerModel')
self._loss_fn = state.model.loss
return

state.batch = state.device.batch_to_device(state.batch)
input, target = state.batch_get_item(key=self.input_key), state.batch_get_item(key=self.target_key)
assert isinstance(input, torch.Tensor) and isinstance(target, torch.Tensor), \
'Multiple tensors not supported for this method yet.'
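Note: with this commit the trainer no longer moves the full batch onto the accelerator before algorithms run, so SelectiveBackprop now calls `state.device.batch_to_device` itself before indexing into the batch. As a rough mental model only (an illustrative sketch, not Composer's actual `Device.batch_to_device` implementation), a recursive batch-to-device move looks like this:

from collections.abc import Mapping

import torch


def batch_to_device_sketch(batch, device: torch.device):
    # Recursively move every tensor in a (possibly nested) batch onto `device`.
    if isinstance(batch, torch.Tensor):
        return batch.to(device, non_blocking=True)
    if isinstance(batch, Mapping):
        return {k: batch_to_device_sketch(v, device) for k, v in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(batch_to_device_sketch(v, device) for v in batch)
    return batch  # non-tensor leaves (ints, strings, ...) stay on the host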
2 changes: 2 additions & 0 deletions composer/algorithms/seq_length_warmup/seq_length_warmup.py
@@ -292,6 +292,8 @@ def _activate_model(self, state: State, logger: Logger) -> None:
while True:
model_inputs = {k: v[:state.device_train_microbatch_size] for k, v in batch_clone.items()}

model_inputs = state.device.batch_to_device(model_inputs)

found_cuda_oom = 0 # int since bool BOR not supported on all torch.distributed backends
try:
# Start by running a forward and backward pass
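Note: `_activate_model` probes for a workable microbatch size, so the inputs it slices from the cloned batch now have to be moved to the device explicitly. The `found_cuda_oom` flag above is an int rather than a bool because it is reduced across ranks, and boolean bitwise-OR is not supported on every torch.distributed backend; a MAX reduction over ints does the same job. A simplified sketch of that pattern (hypothetical helper names, not the code in this file):

import torch
import torch.distributed as dist


def any_rank_oomed(run_forward_backward) -> bool:
    # Returns True if the probe OOMed on this rank or any other rank.
    found_cuda_oom = 0  # int, so it can be all-reduced with MAX on any backend
    try:
        run_forward_backward()
    except RuntimeError as e:
        if 'CUDA out of memory' in str(e):
            found_cuda_oom = 1
        else:
            raise
    if dist.is_available() and dist.is_initialized():
        # Assumes a CUDA-capable backend such as NCCL for the reduction tensor.
        flag = torch.tensor([found_cuda_oom], device='cuda')
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)  # 1 if any rank raised an OOM
        found_cuda_oom = int(flag.item())
    return bool(found_cuda_oom)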
2 changes: 1 addition & 1 deletion composer/core/data_spec.py
@@ -258,7 +258,7 @@ def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
'`get_num_samples_in_batch(your_batch) -> int` method.',
)
dim0_sizes.append(t.shape[0])
elif isinstance(batch, dict):
elif isinstance(batch, Mapping):
for t in batch.values():
if isinstance(t, torch.Tensor):
dim0_sizes.append(t.shape[0])
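Note: widening the check from `dict` to `collections.abc.Mapping` lets the default sample counter handle dict-like batches that are not literal dicts, e.g. `UserDict` subclasses such as HuggingFace tokenizer outputs (`BatchEncoding`). A quick illustration of why the old check missed them:

from collections import UserDict
from collections.abc import Mapping

import torch

batch = UserDict({
    'input_ids': torch.zeros(8, 128, dtype=torch.long),
    'labels': torch.zeros(8, dtype=torch.long),
})

print(isinstance(batch, dict))     # False -- UserDict does not subclass dict
print(isinstance(batch, Mapping))  # True  -- so dim0 sizes are now collected from it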
4 changes: 2 additions & 2 deletions composer/trainer/trainer.py
@@ -2659,7 +2659,6 @@ def _train_loop(self) -> None:
self._rng_state = None
continue

self.state.batch = self.state.device.batch_to_device(self.state.batch)
self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
@@ -3071,6 +3070,7 @@ def _train_microbatches(
current_batch = self.state.batch

for microbatch_idx, self.state.batch in enumerate(microbatches):
self.state.batch = self.state.device.batch_to_device(self.state.batch)
is_final_microbatch = microbatch_idx + 1 == len(microbatches)
microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)

@@ -3619,7 +3619,6 @@ def _eval_loop(
)

for self.state.batch in self._iter_dataloader(TrainerMode.EVAL):
self.state.batch = self.state.device.batch_to_device(self.state.batch)
self.state.batch = data_spec.device_transforms(self.state.batch)

# Count the batch size and num tokens before any events run
@@ -3649,6 +3648,7 @@
try:
microbatches = data_spec.split_batch(device_batch, evaluator.device_eval_microbatch_size)
for i, self.state.batch in enumerate(microbatches):
self.state.batch = self.state.device.batch_to_device(self.state.batch)
last_microbatch = i == len(microbatches) - 1
skip_metric_update = False
# Distributed samplers pad batches to be the same size. If using a
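Note: the trainer change is the core of the commit. The global batch is no longer moved onto the accelerator before `device_transforms` and batch splitting; instead each microbatch is moved immediately before it is trained or evaluated, so at most one microbatch's worth of data sits in GPU memory at a time. A simplified sketch of the new ordering (all names here are placeholders, not the trainer's actual signatures):

def train_one_batch_sketch(batch, device_transforms, split_batch, to_device, step, microbatch_size):
    # Old ordering: `batch = to_device(batch)` ran first, putting the entire
    # global batch on the GPU before it was split into microbatches.
    batch = device_transforms(batch)                        # runs on the host batch
    for microbatch in split_batch(batch, microbatch_size):  # split on the host
        microbatch = to_device(microbatch)                  # only now move this slice
        step(microbatch)                                    # forward/backward/optimizer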
43 changes: 43 additions & 0 deletions tests/trainer/test_trainer.py
@@ -207,6 +207,49 @@ def test_eval_metrics(self):
assert len(eval_metric_names) == 1
assert next(iter(eval_metric_names)) == single_metric

    @pytest.mark.gpu
    def test_memory_after_dataloader(self, model: ComposerModel):

        def track_memory_after_dataloader(global_batch_size):

            class MiniMemoryMonitor(Callback):

                def __init__(self):
                    self.batch_memory_usages = []

                def epoch_start(self, state: State, logger: Logger) -> None:
                    current_alloc_memory = torch.cuda.memory_allocated() // 2**20  # Convert to MiB
                    self.batch_memory_usages.append(current_alloc_memory)

                def after_dataloader(self, state: State, logger: Logger):
                    current_alloc_memory = torch.cuda.memory_allocated() // 2**20  # Convert to MiB
                    self.batch_memory_usages.append(current_alloc_memory)

            microbatch_size = 1
            input_shape = (100000,)
            dataset = RandomClassificationDataset(shape=input_shape, size=1024)
            train_dataloader = DataLoader(dataset, batch_size=global_batch_size)
            mini_memory_monitor = MiniMemoryMonitor()

            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                max_duration='1ba',
                device='gpu',
                device_train_microbatch_size=microbatch_size,
                callbacks=[mini_memory_monitor],
            )

            trainer.fit()
            return mini_memory_monitor.batch_memory_usages[1] - mini_memory_monitor.batch_memory_usages[0]

        global_batch_size = 32
        mem_change_epoch_start_and_after_dataloader = track_memory_after_dataloader(global_batch_size)
        assert (mem_change_epoch_start_and_after_dataloader < 1), (
            f'Memory increased between epoch start and after dataloader by more than 1 MiB: {mem_change_epoch_start_and_after_dataloader} MiB. '
            f'None of the samples should be moved onto a GPU until the batch has already been divided into microbatches.'
        )


def _assert_optimizer_is_on_device(optimizer: torch.optim.Optimizer):
for state in optimizer.state.values():
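Note: the new test compares allocated CUDA memory at `epoch_start` with the allocation right after the `after_dataloader` event. If the full 32-sample batch of 100,000-element inputs were still moved to the GPU before microbatching, allocation would grow by roughly 32 * 100000 * 4 bytes ≈ 12 MiB (assuming float32 samples), far above the 1 MiB threshold, whereas deferring the move to the microbatch loop keeps the growth near zero at that point. A standalone sketch of the same measurement idiom (assumes a CUDA device is available):

import torch


def allocated_mib() -> int:
    # Currently allocated CUDA memory on the default device, in MiB.
    return torch.cuda.memory_allocated() // 2**20


if torch.cuda.is_available():
    before = allocated_mib()
    x = torch.zeros(32, 100000, device='cuda')  # ~12 MiB of float32, about one global batch
    after = allocated_mib()
    print(f'allocation grew by {after - before} MiB')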
