Allowing for fractional number of samples per rank (mosaicml#3075)
* made changes to allow for fractional num_samples per rank

* Minor typo fix

* fixing skip metric condition in eval loop

* Changes to allow specifying fractional microbatches

* Fixed TODOs related to autobatching and reducing microbatch size for the last batch when drop_last is False

* added some comments

* fixing pyright issues

* raise error if dtms > 1 in seq parallelism

* added test for _accumulate_time_across_ranks

* fixed the test

* added error threshold

* fixing pyright errors

* changed spg_ws to seq_parallel_world_size

* lint

* lint

* lint

* lint

* lint

* lint
ShashankMosaicML authored Mar 8, 2024
1 parent ba16b1d commit c5869d2
Showing 3 changed files with 110 additions and 11 deletions.
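For context, the motivation is that with sequence parallelism each sample's sequence is sharded across the ranks in a sequence-parallel group, so a single rank may hold only a fraction of a sample. The sketch below is illustrative only (the variable names are not Composer APIs); it shows why per-rank sample counts become fractional and why the new checks require the per-rank microbatch size times the sequence-parallel world size to equal one full sample:

# Illustration only: per-rank sample accounting under sequence parallelism.
seq_parallel_world_size = 2  # ranks that jointly process each sample's sequence
device_train_microbatch_size = 1 / seq_parallel_world_size  # 0.5 samples per rank
assert device_train_microbatch_size * seq_parallel_world_size == 1

# Summing the fractional per-rank counts across all ranks recovers an integer,
# which is what the trainer verifies after the all-reduce (see the diff below).
world_size = 8
total_samples = device_train_microbatch_size * world_size  # 4.0
assert abs(total_samples - round(total_samples)) < 1e-4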
4 changes: 2 additions & 2 deletions composer/core/data_spec.py
@@ -156,7 +156,7 @@ class DataSpec:
the ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, Tuple, or List, then
this function must be specified.
get_num_samples_in_batch ((Batch) -> int, optional): Function that is called by the :class:`.Trainer`
get_num_samples_in_batch ((Batch) -> Union[int, float], optional): Function that is called by the :class:`.Trainer`
to get the number of samples in the provided batch.
By default, if the batch contains tensors that all have the same 0th dim, then the value of the 0th dim will
@@ -178,7 +178,7 @@ def __init__(
num_tokens: Optional[int] = None,
device_transforms: Optional[Callable[[Batch], Batch]] = None,
split_batch: Optional[Callable[[Batch, int], Sequence[Batch]]] = None,
get_num_samples_in_batch: Optional[Callable[[Batch], int]] = None,
get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
get_num_tokens_in_batch: Optional[Callable[[Batch], int]] = None,
) -> None:
self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
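As a hedged usage sketch of the widened signature (not code from this commit; the sequence-parallel world size of 2 is an assumed setting), a DataSpec can now supply a `get_num_samples_in_batch` callback that returns a float when each rank holds only part of every sample:

import torch
from torch.utils.data import DataLoader, TensorDataset

from composer.core import DataSpec

SEQ_PARALLEL_WORLD_SIZE = 2  # assumed setting, not a Composer constant


def fractional_num_samples(batch):
    # Each rank holds 1 / SEQ_PARALLEL_WORLD_SIZE of every sample, so the
    # per-rank sample count may be fractional (e.g. 0.5 for a 1-sample batch).
    inputs, _ = batch
    return inputs.shape[0] / SEQ_PARALLEL_WORLD_SIZE


dataset = TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,)))
data_spec = DataSpec(
    dataloader=DataLoader(dataset, batch_size=2),
    get_num_samples_in_batch=fractional_num_samples,
)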
83 changes: 74 additions & 9 deletions composer/trainer/trainer.py
@@ -289,7 +289,7 @@ def _get_initial_device_train_microbatch_size(
"`device_train_microbatch_size='auto'` requires the `state.train_dataloader` to have a `batch_size` attribute.",
) from e
return batch_size
elif isinstance(device_train_microbatch_size, int):
elif isinstance(device_train_microbatch_size, Union[int, float]):
return device_train_microbatch_size
else:
raise ValueError("device_train_microbatch_size must be an int or ``'auto'``")
@@ -1087,6 +1087,16 @@ def __init__(

# Microbatching
auto_microbatching = _is_auto_microbatching(device_train_microbatch_size, device=device)
if auto_microbatching and train_dataloader is not None and hasattr(train_dataloader, 'seq_parallel_world_size'):
raise ValueError('`device_train_microbatch_size="auto"` is not compatible with sequence parallelism.')
if isinstance(device_train_microbatch_size, int) and train_dataloader is not None and hasattr(
train_dataloader,
'seq_parallel_world_size',
) and device_train_microbatch_size * train_dataloader.seq_parallel_world_size != 1: # type: ignore
raise ValueError(
'`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
)

if auto_microbatching and profiler:
raise ValueError(
"`device_train_microbatch_size='auto'` is not compatible with the profiler. It is "
@@ -1419,6 +1429,15 @@ def __init__(

for evaluator in evaluators:
validate_eval_automicrobatching(evaluator.auto_microbatching, self.state.device)
if evaluator.auto_microbatching and hasattr(evaluator.dataloader, 'seq_parallel_world_size'):
raise ValueError('`validate_eval_automicrobatching` is not compatible with sequence parallelism.')
if isinstance(evaluator.dataloader.get_num_samples_in_batch, int) and hasattr(
evaluator.dataloader,
'seq_parallel_world_size',
) and evaluator.dataloader.get_num_samples_in_batch * evaluator.dataloader.seq_parallel_world_size != 1: # type: ignore
raise ValueError(
'`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
)
if len(evaluators) == 0:
if eval_subset_num_batches != -1:
raise ValueError(
@@ -2098,6 +2117,15 @@ def fit(

for evaluator in evaluators:
validate_eval_automicrobatching(evaluator.auto_microbatching, self.state.device)
if evaluator.auto_microbatching and hasattr(evaluator.dataloader, 'seq_parallel_world_size'):
raise ValueError('`validate_eval_automicrobatching` is not compatible with sequence parallelism.')
if isinstance(evaluator.dataloader.get_num_samples_in_batch, int) and hasattr(
evaluator.dataloader,
'seq_parallel_world_size',
) and evaluator.dataloader.get_num_samples_in_batch * evaluator.dataloader.seq_parallel_world_size != 1: # type: ignore
raise ValueError(
'`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
)

if len(evaluators) == 0:
if eval_subset_num_batches != -1:
@@ -2113,6 +2141,18 @@ def fit(
device_train_microbatch_size,
device=self.state.device,
)
if self.state.auto_microbatching and self._train_data_spec is not None and hasattr(
self._train_data_spec,
'seq_parallel_world_size',
):
raise ValueError('`device_train_microbatch_size="auto"` is not compatible with sequence parallelism.')
if isinstance(device_train_microbatch_size, int) and train_dataloader is not None and hasattr(
train_dataloader,
'seq_parallel_world_size',
) and device_train_microbatch_size * train_dataloader.seq_parallel_world_size != 1: # type: ignore
raise ValueError(
'`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
)
if self.state.auto_microbatching and self.state.profiler:
raise ValueError(
"`device_train_microbatch_size='auto'` is not compatible with the profiler. It is "
@@ -2226,7 +2266,7 @@ def _spin_dataloaders_to_cur_epoch(self):

def _accumulate_time_across_ranks(
self,
num_samples: int,
num_samples: Union[int, float],
num_tokens: int,
batch_time: datetime.timedelta,
) -> Tuple[int, int, datetime.timedelta]:
@@ -2236,10 +2276,22 @@ def _accumulate_time_across_ranks(
"""
# Samples and tokens should be summed
# Batch time should be the value from rank 0
sample_token_tensor = self.state.device.tensor_to_device(
torch.tensor([num_samples, num_tokens], dtype=torch.int),
)

# num_samples can be floating point if we are doing sequence parallelism, since in that case each rank works on only a part of the sample. For example, with sequence parallelism world size 2, each rank trains on half of a sample.
if isinstance(num_samples, float):
sample_token_tensor = self.state.device.tensor_to_device(
torch.tensor([num_samples, num_tokens], dtype=torch.float32),
)
else:
sample_token_tensor = self.state.device.tensor_to_device(
torch.tensor([num_samples, num_tokens], dtype=torch.int),
)
dist.all_reduce(sample_token_tensor, reduce_operation='SUM')
if isinstance(num_samples, float):
sample_token_tensor_int = sample_token_tensor.to(torch.int)
if torch.any(torch.abs(sample_token_tensor_int - sample_token_tensor) > 1e-4):
raise ValueError('The sums of samples and tokens across ranks should each be integers.')
sample_token_tensor = sample_token_tensor_int
batch_time_tensor = self.state.device.tensor_to_device(
torch.tensor([batch_time.total_seconds()], dtype=torch.float32),
)
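As a worked illustration of the branch above (a local simulation, not using torch.distributed): with a sequence-parallel world size of 2, each of the two ranks reports 0.5 samples, the summed tensor holds whole numbers, and casting back to int therefore loses nothing:

import torch

# Simulate the all-reduce by summing the per-rank [num_samples, num_tokens] tensors.
per_rank = [torch.tensor([0.5, 10.0], dtype=torch.float32) for _ in range(2)]
sample_token_tensor = torch.stack(per_rank).sum(dim=0)  # tensor([1., 20.])

sample_token_tensor_int = sample_token_tensor.to(torch.int)
assert not torch.any(torch.abs(sample_token_tensor_int - sample_token_tensor) > 1e-4)
print(sample_token_tensor_int.tolist())  # [1, 20]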
@@ -2657,14 +2709,14 @@ def _train_microbatches(
def _train_microbatch(
self,
use_grad_scaling: bool,
current_batch_size: int,
current_batch_size: Union[int, float],
is_final_microbatch: bool,
) -> Dict[str, torch.Tensor]:
"""Train and compute the loss of ``state.batch``, which is assumed to be a single microbatch.
Args:
use_grad_scaling (bool): Whether to use gradient scaling.
current_batch_size (int): The current batch size.
current_batch_size (int, float): The current batch size.
minibatch_num_samples (int): Number of samples in the minibatch.
is_final_microbatch (bool): If current microbatch is the last one.
"""
@@ -3079,6 +3131,15 @@ def eval(

for evaluator in evaluators:
validate_eval_automicrobatching(evaluator.auto_microbatching, self.state.device)
if evaluator.auto_microbatching and hasattr(evaluator.dataloader, 'seq_parallel_world_size'):
raise ValueError('`validate_eval_automicrobatching` is not compatible with sequence parallelism.')
if isinstance(evaluator.dataloader.get_num_samples_in_batch, int) and hasattr(
evaluator.dataloader,
'seq_parallel_world_size',
) and evaluator.dataloader.get_num_samples_in_batch * evaluator.dataloader.seq_parallel_world_size != 1: # type: ignore
raise ValueError(
'`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
)

self.state.evaluators.extend(evaluators) # Add evaluators to state.evaluators
else:
@@ -3179,7 +3240,9 @@ def _eval_loop(
if dist_sampler is not None and drop_last == False and dataset_len is not None:
batch_num_samples_tensor = self.state.device.tensor_to_device(torch.tensor(rank_num_samples))
dist.all_reduce(batch_num_samples_tensor, reduce_operation='SUM')
batch_num_samples = batch_num_samples_tensor.item()
batch_num_samples = int(batch_num_samples_tensor.item())
if abs(batch_num_samples - batch_num_samples_tensor.item()) > 1e-4:
raise ValueError('Number of samples in a batch should be an integer.')
last_batch = self.state.eval_timestamp.sample + batch_num_samples >= dataset_len

if self.state.deepspeed_enabled:
@@ -3206,10 +3269,12 @@ def _eval_loop(
rank_num_samples -= 1
num_samples_in_microbatch = data_spec.get_num_samples_in_batch(self.state.batch)
# Skip updating metric if batch is only padded samples
if num_samples_in_microbatch == 1:
if num_samples_in_microbatch == 1 or hasattr(data_spec, 'seq_parallel_world_size'):
skip_metric_update = True
# Remove padded samples from batch
else:
if not isinstance(num_samples_in_microbatch, int):
raise ValueError('Number of samples in a batch should be an integer.')
self.state.batch = data_spec.split_batch(
self.state.batch,
num_samples_in_microbatch - 1,
34 changes: 34 additions & 0 deletions tests/trainer/test_trainer.py
@@ -1212,6 +1212,40 @@ def test_iteration(
assert trainer.state.timestamp.epoch == Time(5, TimeUnit.EPOCH)
assert trainer.state.timestamp.iteration == Time(2, TimeUnit.ITERATION)

@pytest.mark.gpu
@pytest.mark.world_size(2)
@pytest.mark.parametrize('num_samples', [2, 0.5])
def test_accumulate_time_across_ranks(
self,
train_dataloader: DataLoader,
model: ComposerModel,
max_duration: Time[int],
num_samples: Union[int, float],
):
# Train once with the max_duration param on Trainer.__init__()
init_trainer = Trainer(
model=model,
max_duration=max_duration,
train_dataloader=train_dataloader,
)

num_tokens = 10
batch_time = datetime.timedelta(seconds=0.1 * (1 + dist.get_global_rank()))

num_samples_accum, num_tokens_accum, batch_time_accum = init_trainer._accumulate_time_across_ranks(
num_samples,
num_tokens,
batch_time,
)

assert isinstance(num_tokens_accum, int)
assert isinstance(num_samples_accum, int)
assert isinstance(batch_time_accum, datetime.timedelta)

assert num_samples_accum == num_samples * 2
assert num_tokens_accum == num_tokens * 2
assert batch_time_accum == datetime.timedelta(seconds=0.1 * (1 + 0))


@world_size(1, 2)
@device('cpu', 'gpu', 'gpu-amp', precision=True)
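Concretely, with two ranks the new test expects num_samples_accum == 4 for the integer case (2 samples per rank) and num_samples_accum == 1 for the fractional case (0.5 samples per rank), while batch_time_accum is always rank 0's 0.1 seconds rather than a sum across ranks.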
